{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bf83f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies\n",
    "%pip install sf-hamilton[visualization] sentence_transformers datasets lancedb -qU"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06b70433",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# How to use LanceDB with NER semantic search \\[for RAG\\] [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/LLM_Workflows/NER_Example/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/LLM_Workflows/NER_Example/notebook.ipynb)\n",
    "\n",
    "In this notebook we’ll walk through an example Hamilton pipeline that embeds text and also captures extra metadata about that text, which can be used to decide what data to retrieve for RAG. This is a form of \"semantic search\"; we use LanceDB to store the data and query it.\n",
    "\n",
    "Why capture, or rather extract (as you’ll see), extra metadata? Because you can use it to filter results and so improve accuracy. You’ll need more than just cosine similarity to build a high-quality system [\\[1\\]](https://jxnl.co/writing/2024/05/11/low-hanging-fruit-for-rag-search/). [Named Entity Recognition (NER)](https://en.wikipedia.org/wiki/Named-entity_recognition) is just one approach to extracting such metadata from text.\n",
    "\n",
    "> In short, we use the NER model to narrow down the semantic search results. The named entities it predicts act as filters, applied either before (pre-filtering) or after (post-filtering) the vector search. This is particularly helpful when you want to restrict the search to records that mention the same named entities that appear in the query.\n",
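    "\n",
    "A minimal sketch of the difference between the two, in plain Python rather than LanceDB's API: the toy records, 2-d vectors, and helper names (`pre_filter_search`, `post_filter_search`) are hypothetical. Pre-filtering restricts the candidates to records that share an entity with the query and then ranks them; post-filtering ranks all records, keeps the top `k`, and only then drops non-matching ones, so it can return fewer than `k` results.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Hypothetical toy records: an embedding vector plus NER-derived metadata.\n",
    "RECORDS = [\n",
    "    {'id': 1, 'vector': [1.0, 0.0], 'named_entities': ['apple']},\n",
    "    {'id': 2, 'vector': [0.9, 0.1], 'named_entities': ['google']},\n",
    "    {'id': 3, 'vector': [0.8, 0.2], 'named_entities': ['google']},\n",
    "    {'id': 4, 'vector': [0.0, 1.0], 'named_entities': ['apple', 'google']},\n",
    "]\n",
    "\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Cosine similarity between two equal-length vectors.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    return dot / (math.hypot(*a) * math.hypot(*b))\n",
    "\n",
    "\n",
    "def matches(record, query_entities):\n",
    "    # True if the record shares at least one named entity with the query.\n",
    "    return bool(set(record['named_entities']) & set(query_entities))\n",
    "\n",
    "\n",
    "def pre_filter_search(query_vec, query_entities, k):\n",
    "    # Filter first, then rank only the surviving records.\n",
    "    candidates = [r for r in RECORDS if matches(r, query_entities)]\n",
    "    candidates.sort(key=lambda r: cosine(query_vec, r['vector']), reverse=True)\n",
    "    return [r['id'] for r in candidates[:k]]\n",
    "\n",
    "\n",
    "def post_filter_search(query_vec, query_entities, k):\n",
    "    # Rank everything, keep the top k, then drop records that do not match.\n",
    "    ranked = sorted(RECORDS, key=lambda r: cosine(query_vec, r['vector']), reverse=True)\n",
    "    return [r['id'] for r in ranked[:k] if matches(r, query_entities)]\n",
    "\n",
    "\n",
    "print(pre_filter_search([1.0, 0.0], ['apple'], k=2))   # [1, 4]\n",
    "print(post_filter_search([1.0, 0.0], ['apple'], k=2))  # [1]\n",
    "```\n",
    "\n",
    "Post-filtering returns only record 1 here because record 4, the other `apple` record, never made it into the unfiltered top 2.\n",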
    "\n",
    "In this notebook we'll build a processing pipeline and walk through the code to:\n",
    "\n",
    "1. Extract named entities from text.\n",
    "2. Store them in LanceDB as metadata (alongside the embedding vectors).\n",
    "3. Extract named entities from incoming queries and use them to filter the search, so that it only covers records containing those named entities.\n",
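    "\n",
    "As a sketch of how steps 1 and 3 can turn model output into filterable metadata: a `transformers` NER pipeline with an aggregation strategy returns, for each text, a list of dicts with keys such as `entity_group`, `score` and `word`. The sample output below is hand-written (not produced by the model), and the helper name `entity_names`, the `min_score` threshold, and the lowercasing are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "# Hand-written output in the shape a transformers NER pipeline returns when an\n",
    "# aggregation strategy is set (illustrative values, not real model output).\n",
    "SAMPLE_NER_OUTPUT = [\n",
    "    {'entity_group': 'ORG', 'score': 0.99, 'word': 'Google'},\n",
    "    {'entity_group': 'LOC', 'score': 0.98, 'word': 'London'},\n",
    "    {'entity_group': 'MISC', 'score': 0.31, 'word': 'Office'},\n",
    "    {'entity_group': 'ORG', 'score': 0.97, 'word': 'Google'},\n",
    "]\n",
    "\n",
    "\n",
    "def entity_names(ner_output, min_score=0.5):\n",
    "    # Keep confident entities, lowercase them, and de-duplicate in first-seen order.\n",
    "    names = []\n",
    "    for ent in ner_output:\n",
    "        name = ent['word'].strip().lower()\n",
    "        if ent['score'] >= min_score and name not in names:\n",
    "            names.append(name)\n",
    "    return names\n",
    "\n",
    "\n",
    "# A record to store alongside its embedding vector (vector omitted here).\n",
    "record = {\n",
    "    'title_text': 'Google opens a new office in London, Google said.',\n",
    "    'named_entities': entity_names(SAMPLE_NER_OUTPUT),\n",
    "}\n",
    "print(record['named_entities'])  # ['google', 'london']\n",
    "```\n",
    "\n",
    "Lowercasing makes stored and query entities comparable regardless of capitalization; the same helper can be applied to a query's NER output before filtering.\n",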
    "\n",
    "\n",
    "Let's get started."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0aa34fb4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-10T22:52:53.305913Z",
     "start_time": "2024-04-10T22:52:50.668455Z"
    },
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Load Hamilton's Jupyter magic\n",
    "%load_ext hamilton.plugins.jupyter_magic"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f6a5a66",
   "metadata": {},
   "source": [
    "# Load the data\n",
    "We start by loading the dataset from Hugging Face.\n",
    "\n",
    "We use one of the data loaders that ships with Hamilton (`@load_from.hf_dataset`) to load the dataset for us. We could do some filtering within the loading function, but instead we break it out into a separate function that samples and augments the loaded dataset.\n",
    "\n",
    "A note about the sampling below: in real life we’d use the full dataset; we sample here only to keep this example tractable to run. Beyond sampling, we modify the dataset as follows:\n",
    "1. We remove documents with missing (`None`) titles or text.\n",
    "2. We truncate the text to its first 1000 characters. This limits the dataset size and also makes the text fit into the context window of the model that creates our embeddings. In real life you’d probably want to process the entire text, for example by creating separate embeddings for different text chunks.\n",
    "3. To simplify things further, we combine the title and text into a single field for NER and embedding purposes. We assume the title and the first 1000 characters of text contain enough information to capture the gist of the document, create an embedding, and extract the relevant entities.\n"
   ]
  },
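  {
   "cell_type": "markdown",
   "id": "7c2d4e91",
   "metadata": {},
   "source": [
    "The three transformations can be sketched in plain Python on a few hypothetical rows, without downloading the dataset. This mirrors the `filter`, `shuffle`/`select` and `map` steps of `sampled_articles` below, but it uses `random.Random` instead of the `datasets` shuffle, so it will not pick the same rows as the notebook; `sample_rows` and `ROWS` are illustrative names.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "# Hypothetical rows standing in for the Medium articles dataset.\n",
    "ROWS = [\n",
    "    {'title': 'Intro to RAG', 'text': 'Retrieval augmented generation. ' * 100},\n",
    "    {'title': None, 'text': 'An article without a title.'},\n",
    "    {'title': 'Vector databases', 'text': 'LanceDB stores embeddings.'},\n",
    "    {'title': 'Untitled draft', 'text': None},\n",
    "]\n",
    "\n",
    "\n",
    "def sample_rows(rows, sample_size, random_state, max_text_length=1000):\n",
    "    # 1. Drop rows whose title or text is missing (None).\n",
    "    kept = [r for r in rows if r['title'] is not None and r['text'] is not None]\n",
    "    # Shuffle reproducibly with a seed, then keep the first sample_size rows.\n",
    "    random.Random(random_state).shuffle(kept)\n",
    "    sampled = []\n",
    "    for r in kept[:sample_size]:\n",
    "        # 2. Truncate the text to at most max_text_length characters.\n",
    "        text = r['text'][:max_text_length]\n",
    "        # 3. Combine title and truncated text into one field for NER and embedding.\n",
    "        sampled.append({**r, 'text': text, 'title_text': r['title'] + '. ' + text})\n",
    "    return sampled\n",
    "\n",
    "\n",
    "for row in sample_rows(ROWS, sample_size=2, random_state=32):\n",
    "    print(len(row['text']), row['title_text'][:40])\n",
    "```"
   ]
  },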
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "622dd8e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"881pt\" height=\"324pt\"\n",
       " viewBox=\"0.00 0.00 880.65 324.30\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 320.3)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-320.3 876.65,-320.3 876.65,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"81.25,-178.3 81.25,-308.3 166.1,-308.3 166.1,-178.3 81.25,-178.3\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-291\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M860.65,-121.1C860.65,-121.1 752.8,-121.1 752.8,-121.1 746.8,-121.1 740.8,-115.1 740.8,-109.1 740.8,-109.1 740.8,-69.5 740.8,-69.5 740.8,-63.5 746.8,-57.5 752.8,-57.5 752.8,-57.5 860.65,-57.5 860.65,-57.5 866.65,-57.5 872.65,-63.5 872.65,-69.5 872.65,-69.5 872.65,-109.1 872.65,-109.1 872.65,-115.1 866.65,-121.1 860.65,-121.1\"/>\n",
       "<text text-anchor=\"start\" x=\"751.6\" y=\"-98\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"783.48\" y=\"-70\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-168.1C235.35,-168.1 12,-168.1 12,-168.1 6,-168.1 0,-162.1 0,-156.1 0,-156.1 0,-116.5 0,-116.5 0,-110.5 6,-104.5 12,-104.5 12,-104.5 235.35,-104.5 235.35,-104.5 241.35,-104.5 247.35,-110.5 247.35,-116.5 247.35,-116.5 247.35,-156.1 247.35,-156.1 247.35,-162.1 241.35,-168.1 235.35,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-168.1C522.2,-168.1 288.35,-168.1 288.35,-168.1 282.35,-168.1 276.35,-162.1 276.35,-156.1 276.35,-156.1 276.35,-116.5 276.35,-116.5 276.35,-110.5 282.35,-104.5 288.35,-104.5 288.35,-104.5 522.2,-104.5 522.2,-104.5 528.2,-104.5 534.2,-110.5 534.2,-116.5 534.2,-116.5 534.2,-156.1 534.2,-156.1 534.2,-162.1 528.2,-168.1 522.2,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.03\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-136.3C253.29,-136.3 258.89,-136.3 264.49,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-139.8 274.47,-136.3 264.47,-132.8 264.47,-139.8\"/>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M690.3,-168.1C690.3,-168.1 584.7,-168.1 584.7,-168.1 578.7,-168.1 572.7,-162.1 572.7,-156.1 572.7,-156.1 572.7,-116.5 572.7,-116.5 572.7,-110.5 578.7,-104.5 584.7,-104.5 584.7,-104.5 690.3,-104.5 690.3,-104.5 696.3,-104.5 702.3,-110.5 702.3,-116.5 702.3,-116.5 702.3,-156.1 702.3,-156.1 702.3,-162.1 696.3,-168.1 690.3,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"583.5\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"614.25\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M702.68,-118.26C711.45,-115.8 720.53,-113.25 729.49,-110.73\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"730.2,-114.16 738.88,-108.09 728.31,-107.42 730.2,-114.16\"/>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.51,-136.3C543.53,-136.3 552.44,-136.3 561.01,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"560.96,-139.8 570.96,-136.3 560.96,-132.8 560.96,-139.8\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"711.8,-86.6 563.2,-86.6 563.2,0 711.8,0 711.8,-86.6\"/>\n",
       "<text text-anchor=\"start\" x=\"589.63\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"577.63\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"585.88\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M711.85,-63.48C717.62,-65.06 723.46,-66.67 729.26,-68.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"728.29,-71.63 738.86,-70.91 730.15,-64.88 728.29,-71.63\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.68,-222.6 96.68,-222.6 96.68,-186 150.68,-186 150.68,-222.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-198.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-277.6C146.1,-277.6 101.25,-277.6 101.25,-277.6 95.25,-277.6 89.25,-271.6 89.25,-265.6 89.25,-265.6 89.25,-253 89.25,-253 89.25,-247 95.25,-241 101.25,-241 101.25,-241 146.1,-241 146.1,-241 152.1,-241 158.1,-247 158.1,-253 158.1,-253 158.1,-265.6 158.1,-265.6 158.1,-271.6 152.1,-277.6 146.1,-277.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-253.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x14de1c940>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%incr_cell_to_module ner_module 1 --display\n",
    "\n",
    "from datasets import Dataset\n",
    "from hamilton.function_modifiers import load_from, save_to, source, value\n",
    "\n",
    "@load_from.hf_dataset(\n",
    "    path=value(\"fabiochiu/medium-articles\"),\n",
    "    data_files=value(\"medium_articles.csv\"),\n",
    "    split=value(\"train\"),\n",
    ")\n",
    "def medium_articles(dataset: Dataset) -> Dataset:\n",
    "    \"\"\"Loads the Medium articles dataset as a Hugging Face Dataset.\"\"\"\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def sampled_articles(\n",
    "    medium_articles: Dataset,\n",
    "    sample_size: int = 104,\n",
    "    random_state: int = 32,\n",
    "    max_text_length: int = 1000,\n",
    ") -> Dataset:\n",
    "    \"\"\"Samples the articles and applies some light transformations.\n",
    "\n",
    "    Transformations:\n",
    "     - drops entries whose title or text is missing.\n",
    "     - truncates the text to its first `max_text_length` characters. This is for performance here;\n",
    "       in real life you'd do something suited to your use case.\n",
    "     - joins the article title and text into one `title_text` string.\n",
    "    \"\"\"\n",
    "    # Filter out entries with None values in the 'text' or 'title' fields\n",
    "    dataset = medium_articles.filter(\n",
    "        lambda example: example[\"text\"] is not None and example[\"title\"] is not None\n",
    "    )\n",
    "\n",
    "    # Shuffle and take the first `sample_size` samples\n",
    "    dataset = dataset.shuffle(seed=random_state).select(range(sample_size))\n",
    "\n",
    "    # Truncate the 'text' to the first `max_text_length` characters\n",
    "    dataset = dataset.map(lambda example: {\"text\": example[\"text\"][:max_text_length]})\n",
    "\n",
    "    # Concatenate the 'title' and truncated 'text'\n",
    "    dataset = dataset.map(lambda example: {\"title_text\": example[\"title\"] + \". \" + example[\"text\"]})\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79ec8de3",
   "metadata": {},
   "source": [
    "# Create the NER tokenizer and model\n",
    "We can now add to our pipeline the loading of the tokenizer and model that will extract entities from text for us. The NER model here (`dslim/bert-base-NER`) is a BERT-base model fine-tuned for NER. All models are loaded from the Hugging Face hub.\n"
   ]
  },
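  {
   "cell_type": "markdown",
   "id": "b3f8a615",
   "metadata": {},
   "source": [
    "The pipeline below uses `aggregation_strategy=\"max\"`. As we read the `transformers` documentation, this merges word pieces back into whole words, gives each word the label of its highest-scoring piece, and then groups adjacent words of the same entity type (labels follow the BIO scheme, e.g. `B-PER`/`I-PER`). The following is a simplified, hypothetical re-implementation in plain Python, for illustration only: the tokens, labels, and scores are hand-written, and the real logic inside the `transformers` token-classification pipeline handles more edge cases.\n",
    "\n",
    "```python\n",
    "# Hand-written (wordpiece, label, score) triples; '##' marks a continuation piece.\n",
    "TOKENS = [\n",
    "    ('Ada', 'B-PER', 0.99),\n",
    "    ('Love', 'I-PER', 0.97),\n",
    "    ('##lace', 'I-PER', 0.90),\n",
    "    ('works', 'O', 0.99),\n",
    "    ('with', 'O', 0.99),\n",
    "    ('Lance', 'B-ORG', 0.95),\n",
    "    ('##DB', 'I-ORG', 0.60),\n",
    "]\n",
    "\n",
    "\n",
    "def words_from_pieces(tokens):\n",
    "    # Merge '##' pieces into the previous word; each word keeps the label and\n",
    "    # score of its highest-scoring piece (the idea behind the 'max' strategy).\n",
    "    words = []\n",
    "    for piece, label, score in tokens:\n",
    "        if piece.startswith('##') and words:\n",
    "            word = words[-1]\n",
    "            word['text'] += piece[2:]\n",
    "            if score > word['score']:\n",
    "                word['label'], word['score'] = label, score\n",
    "        else:\n",
    "            words.append({'text': piece, 'label': label, 'score': score})\n",
    "    return words\n",
    "\n",
    "\n",
    "def group_entities(words):\n",
    "    # Join consecutive words of one entity type; a B- tag or an O word ends the\n",
    "    # current entity. The entity score is the mean of its word scores.\n",
    "    groups, current = [], None\n",
    "    for w in words:\n",
    "        if w['label'] == 'O':\n",
    "            current = None\n",
    "            continue\n",
    "        tag, ent_type = w['label'].split('-', 1)\n",
    "        if current is not None and tag == 'I' and current['entity_group'] == ent_type:\n",
    "            current['words'].append(w['text'])\n",
    "            current['scores'].append(w['score'])\n",
    "        else:\n",
    "            current = {'entity_group': ent_type, 'words': [w['text']], 'scores': [w['score']]}\n",
    "            groups.append(current)\n",
    "    return [\n",
    "        {\n",
    "            'entity_group': g['entity_group'],\n",
    "            'word': ' '.join(g['words']),\n",
    "            'score': sum(g['scores']) / len(g['scores']),\n",
    "        }\n",
    "        for g in groups\n",
    "    ]\n",
    "\n",
    "\n",
    "entities = group_entities(words_from_pieces(TOKENS))\n",
    "print([(e['entity_group'], e['word']) for e in entities])\n",
    "# [('PER', 'Ada Lovelace'), ('ORG', 'LanceDB')]\n",
    "```"
   ]
  },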
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "973ef2f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"881pt\" height=\"529pt\"\n",
       " viewBox=\"0.00 0.00 880.65 529.30\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 525.3)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-525.3 876.65,-525.3 876.65,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"81.25,-383.3 81.25,-513.3 166.1,-513.3 166.1,-383.3 81.25,-383.3\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-496\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M860.65,-121.1C860.65,-121.1 752.8,-121.1 752.8,-121.1 746.8,-121.1 740.8,-115.1 740.8,-109.1 740.8,-109.1 740.8,-69.5 740.8,-69.5 740.8,-63.5 746.8,-57.5 752.8,-57.5 752.8,-57.5 860.65,-57.5 860.65,-57.5 866.65,-57.5 872.65,-63.5 872.65,-69.5 872.65,-69.5 872.65,-109.1 872.65,-109.1 872.65,-115.1 866.65,-121.1 860.65,-121.1\"/>\n",
       "<text text-anchor=\"start\" x=\"751.6\" y=\"-98\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"783.48\" y=\"-70\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- tokenizer -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>tokenizer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M466.33,-414.1C466.33,-414.1 344.23,-414.1 344.23,-414.1 338.23,-414.1 332.23,-408.1 332.23,-402.1 332.23,-402.1 332.23,-362.5 332.23,-362.5 332.23,-356.5 338.23,-350.5 344.23,-350.5 344.23,-350.5 466.33,-350.5 466.33,-350.5 472.33,-350.5 478.33,-356.5 478.33,-362.5 478.33,-362.5 478.33,-402.1 478.33,-402.1 478.33,-408.1 472.33,-414.1 466.33,-414.1\"/>\n",
       "<text text-anchor=\"start\" x=\"375.65\" y=\"-391\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">tokenizer</text>\n",
       "<text text-anchor=\"start\" x=\"343.03\" y=\"-363\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedTokenizer</text>\n",
       "</g>\n",
       "<!-- ner_pipeline -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>ner_pipeline</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M676.05,-332.1C676.05,-332.1 598.95,-332.1 598.95,-332.1 592.95,-332.1 586.95,-326.1 586.95,-320.1 586.95,-320.1 586.95,-280.5 586.95,-280.5 586.95,-274.5 592.95,-268.5 598.95,-268.5 598.95,-268.5 676.05,-268.5 676.05,-268.5 682.05,-268.5 688.05,-274.5 688.05,-280.5 688.05,-280.5 688.05,-320.1 688.05,-320.1 688.05,-326.1 682.05,-332.1 676.05,-332.1\"/>\n",
       "<text text-anchor=\"start\" x=\"597.75\" y=\"-309\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">ner_pipeline</text>\n",
       "<text text-anchor=\"start\" x=\"613.5\" y=\"-281\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Pipeline</text>\n",
       "</g>\n",
       "<!-- tokenizer&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>tokenizer&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M478.63,-360.1C496.85,-354.24 516.32,-347.74 534.2,-341.3 547.9,-336.37 562.49,-330.74 576.24,-325.28\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"577.18,-328.67 585.16,-321.7 574.57,-322.17 577.18,-328.67\"/>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-168.1C522.2,-168.1 288.35,-168.1 288.35,-168.1 282.35,-168.1 276.35,-162.1 276.35,-156.1 276.35,-156.1 276.35,-116.5 276.35,-116.5 276.35,-110.5 282.35,-104.5 288.35,-104.5 288.35,-104.5 522.2,-104.5 522.2,-104.5 528.2,-104.5 534.2,-110.5 534.2,-116.5 534.2,-116.5 534.2,-156.1 534.2,-156.1 534.2,-162.1 528.2,-168.1 522.2,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.03\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M690.3,-168.1C690.3,-168.1 584.7,-168.1 584.7,-168.1 578.7,-168.1 572.7,-162.1 572.7,-156.1 572.7,-156.1 572.7,-116.5 572.7,-116.5 572.7,-110.5 578.7,-104.5 584.7,-104.5 584.7,-104.5 690.3,-104.5 690.3,-104.5 696.3,-104.5 702.3,-110.5 702.3,-116.5 702.3,-116.5 702.3,-156.1 702.3,-156.1 702.3,-162.1 696.3,-168.1 690.3,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"583.5\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"614.25\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.51,-136.3C543.53,-136.3 552.44,-136.3 561.01,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"560.96,-139.8 570.96,-136.3 560.96,-132.8 560.96,-139.8\"/>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-168.1C235.35,-168.1 12,-168.1 12,-168.1 6,-168.1 0,-162.1 0,-156.1 0,-156.1 0,-116.5 0,-116.5 0,-110.5 6,-104.5 12,-104.5 12,-104.5 235.35,-104.5 235.35,-104.5 241.35,-104.5 247.35,-110.5 247.35,-116.5 247.35,-116.5 247.35,-156.1 247.35,-156.1 247.35,-162.1 241.35,-168.1 235.35,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-136.3C253.29,-136.3 258.89,-136.3 264.49,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-139.8 274.47,-136.3 264.47,-132.8 264.47,-139.8\"/>\n",
       "</g>\n",
       "<!-- NER_model_id -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>NER_model_id</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M170.48,-373.1C170.48,-373.1 76.88,-373.1 76.88,-373.1 70.88,-373.1 64.88,-367.1 64.88,-361.1 64.88,-361.1 64.88,-321.5 64.88,-321.5 64.88,-315.5 70.88,-309.5 76.88,-309.5 76.88,-309.5 170.48,-309.5 170.48,-309.5 176.48,-309.5 182.48,-315.5 182.48,-321.5 182.48,-321.5 182.48,-361.1 182.48,-361.1 182.48,-367.1 176.48,-373.1 170.48,-373.1\"/>\n",
       "<text text-anchor=\"start\" x=\"75.68\" y=\"-350\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">NER_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"116.18\" y=\"-322\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;tokenizer -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;tokenizer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.95,-349.85C222.77,-355.68 276.17,-363.51 320.53,-370.02\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"320,-373.48 330.4,-371.47 321.01,-366.55 320,-373.48\"/>\n",
       "</g>\n",
       "<!-- model -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M456.2,-332.1C456.2,-332.1 354.35,-332.1 354.35,-332.1 348.35,-332.1 342.35,-326.1 342.35,-320.1 342.35,-320.1 342.35,-280.5 342.35,-280.5 342.35,-274.5 348.35,-268.5 354.35,-268.5 354.35,-268.5 456.2,-268.5 456.2,-268.5 462.2,-268.5 468.2,-274.5 468.2,-280.5 468.2,-280.5 468.2,-320.1 468.2,-320.1 468.2,-326.1 462.2,-332.1 456.2,-332.1\"/>\n",
       "<text text-anchor=\"start\" x=\"385.03\" y=\"-309\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model</text>\n",
       "<text text-anchor=\"start\" x=\"353.15\" y=\"-281\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedModel</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;model -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.95,-332.75C225.97,-326.45 284.84,-317.81 331.04,-311.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"331.31,-314.54 340.7,-309.62 330.3,-307.61 331.31,-314.54\"/>\n",
       "</g>\n",
       "<!-- model&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>model&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M468.58,-300.3C501.69,-300.3 542.25,-300.3 575.22,-300.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"575.02,-303.8 585.02,-300.3 575.02,-296.8 575.02,-303.8\"/>\n",
       "</g>\n",
       "<!-- device -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>device</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M425.08,-250.1C425.08,-250.1 385.48,-250.1 385.48,-250.1 379.48,-250.1 373.48,-244.1 373.48,-238.1 373.48,-238.1 373.48,-198.5 373.48,-198.5 373.48,-192.5 379.48,-186.5 385.48,-186.5 385.48,-186.5 425.08,-186.5 425.08,-186.5 431.08,-186.5 437.08,-192.5 437.08,-198.5 437.08,-198.5 437.08,-238.1 437.08,-238.1 437.08,-244.1 431.08,-250.1 425.08,-250.1\"/>\n",
       "<text text-anchor=\"start\" x=\"384.28\" y=\"-227\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">device</text>\n",
       "<text text-anchor=\"start\" x=\"397.78\" y=\"-199\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- device&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>device&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M437.5,-227.63C463.53,-235.57 501.51,-247.53 534.2,-259.3 547.9,-264.23 562.49,-269.86 576.24,-275.32\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"574.57,-278.43 585.16,-278.9 577.18,-271.93 574.57,-278.43\"/>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M702.68,-118.26C711.45,-115.8 720.53,-113.25 729.49,-110.73\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"730.2,-114.16 738.88,-108.09 728.31,-107.42 730.2,-114.16\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"711.8,-86.6 563.2,-86.6 563.2,0 711.8,0 711.8,-86.6\"/>\n",
       "<text text-anchor=\"start\" x=\"589.63\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"577.63\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"585.88\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"682.88\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M711.85,-63.48C717.62,-65.06 723.46,-66.67 729.26,-68.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"728.29,-71.63 738.86,-70.91 730.15,-64.88 728.29,-71.63\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.68,-427.6 96.68,-427.6 96.68,-391 150.68,-391 150.68,-427.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-403.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-482.6C146.1,-482.6 101.25,-482.6 101.25,-482.6 95.25,-482.6 89.25,-476.6 89.25,-470.6 89.25,-470.6 89.25,-458 89.25,-458 89.25,-452 95.25,-446 101.25,-446 101.25,-446 146.1,-446 146.1,-446 152.1,-446 158.1,-452 158.1,-458 158.1,-458 158.1,-470.6 158.1,-470.6 158.1,-476.6 152.1,-482.6 146.1,-482.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.68\" y=\"-458.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x14de1ea40>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%incr_cell_to_module ner_module 2 --display\n",
    "\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoModelForTokenClassification,\n",
    "    AutoTokenizer,\n",
    "    PreTrainedModel,\n",
    "    PreTrainedTokenizer,\n",
    "    pipeline,\n",
    ")\n",
    "from transformers.pipelines import base\n",
    "\n",
    "def device() -> str:\n",
    "    \"\"\"Returns 'cuda' if a CUDA device is available, otherwise 'cpu'.\"\"\"\n",
    "    return \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "\n",
    "def NER_model_id() -> str:\n",
    "    \"\"\"Model ID to use\n",
    "    To extract named entities, we will use a NER model fine-tuned from a BERT-base model.\n",
    "    The model can be loaded from the HuggingFace model hub.\n",
    "    Use `overrides={\"NER_model_id\": VALUE}` to switch this without changing code.\n",
    "    \"\"\"\n",
    "    return \"dslim/bert-base-NER\"\n",
    "\n",
    "\n",
    "def tokenizer(NER_model_id: str) -> PreTrainedTokenizer:\n",
    "    \"\"\"Loads the tokenizer for the NER model ID from huggingface\"\"\"\n",
    "    return AutoTokenizer.from_pretrained(NER_model_id)\n",
    "\n",
    "\n",
    "def model(NER_model_id: str) -> PreTrainedModel:\n",
    "    \"\"\"Loads the NER model from the Hugging Face Hub.\"\"\"\n",
    "    return AutoModelForTokenClassification.from_pretrained(NER_model_id)\n",
    "\n",
    "\n",
    "def ner_pipeline(\n",
    "    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: str\n",
    ") -> base.Pipeline:\n",
    "    \"\"\"Combines the tokenizer and model into a Hugging Face NER pipeline.\"\"\"\n",
    "    device_no = torch.cuda.current_device() if device == \"cuda\" else None\n",
    "    return pipeline(\n",
    "        \"ner\", model=model, tokenizer=tokenizer, aggregation_strategy=\"max\", device=device_no\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d570570e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'entity_group': 'ORG',\n",
       "   'score': 0.9978863,\n",
       "   'word': 'Mars Rover',\n",
       "   'start': 4,\n",
       "   'end': 14},\n",
       "  {'entity_group': 'ORG',\n",
       "   'score': 0.99731904,\n",
       "   'word': 'NASA',\n",
       "   'start': 20,\n",
       "   'end': 24}]]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# this is what the NER pipeline produces\n",
    "text = \"The Mars Rover from NASA reached the red planet yesterday.\"\n",
    "ner_module.ner_pipeline(ner_module.model(ner_module.NER_model_id()), ner_module.tokenizer(ner_module.NER_model_id()), \"cpu\")([text])"
   ]
  },
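  {
   "cell_type": "markdown",
   "id": "5c1e9a7f",
   "metadata": {},
   "source": [
    "The pipeline returns one list of entity dicts per input text. To use the entities as filtering metadata, you typically only need the entity strings. Below is a minimal sketch of such a helper; the name `entity_words` and its `min_score` threshold are hypothetical and are not part of this example's code.\n",
    "\n",
    "```python\n",
    "from typing import Dict, List\n",
    "\n",
    "# Shape of the NER pipeline output shown above:\n",
    "# one list of entity dicts per input text.\n",
    "NEROutput = List[List[Dict]]\n",
    "\n",
    "\n",
    "def entity_words(ner_output: NEROutput, min_score: float = 0.0) -> List[List[str]]:\n",
    "    # Hypothetical helper: keep each entity string (e.g. 'NASA') whose score\n",
    "    # is at least min_score, dropping duplicates but keeping first-seen order.\n",
    "    results = []\n",
    "    for doc_entities in ner_output:\n",
    "        words = []\n",
    "        for ent in doc_entities:\n",
    "            if ent['score'] >= min_score and ent['word'] not in words:\n",
    "                words.append(ent['word'])\n",
    "        results.append(words)\n",
    "    return results\n",
    "\n",
    "\n",
    "# The output from the cell above, copied by hand\n",
    "sample = [[\n",
    "    {'entity_group': 'ORG', 'score': 0.9978863, 'word': 'Mars Rover', 'start': 4, 'end': 14},\n",
    "    {'entity_group': 'ORG', 'score': 0.99731904, 'word': 'NASA', 'start': 20, 'end': 24},\n",
    "]]\n",
    "print(entity_words(sample))  # [['Mars Rover', 'NASA']]\n",
    "```\n",
    "\n",
    "A per-row list of strings like this is a simple format to store alongside each embedding and to match against the entities found in a query."
   ]
  },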
  {
   "cell_type": "markdown",
   "id": "8b8a4a7a",
   "metadata": {},
   "source": [
    "# Create the embedding model\n",
    "Next we load the retriever model that creates embeddings, i.e. vectors of floats that encode our text. It embeds each passage (article title + first 1000 characters) when we build the index, and the search query at inference time, such that queries and passages with similar meanings are close to each other in the vector space. We use a sentence-transformer model as our retriever; it can be loaded with the following code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c20633ca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"885pt\" height=\"591pt\"\n",
       " viewBox=\"0.00 0.00 884.90 591.30\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 587.3)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-587.3 880.9,-587.3 880.9,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"81.25,-445.3 81.25,-575.3 166.1,-575.3 166.1,-445.3 81.25,-445.3\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-558\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M864.9,-121.1C864.9,-121.1 757.05,-121.1 757.05,-121.1 751.05,-121.1 745.05,-115.1 745.05,-109.1 745.05,-109.1 745.05,-69.5 745.05,-69.5 745.05,-63.5 751.05,-57.5 757.05,-57.5 757.05,-57.5 864.9,-57.5 864.9,-57.5 870.9,-57.5 876.9,-63.5 876.9,-69.5 876.9,-69.5 876.9,-109.1 876.9,-109.1 876.9,-115.1 870.9,-121.1 864.9,-121.1\"/>\n",
       "<text text-anchor=\"start\" x=\"755.85\" y=\"-98\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"787.73\" y=\"-70\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- tokenizer -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>tokenizer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M466.33,-476.1C466.33,-476.1 344.23,-476.1 344.23,-476.1 338.23,-476.1 332.23,-470.1 332.23,-464.1 332.23,-464.1 332.23,-424.5 332.23,-424.5 332.23,-418.5 338.23,-412.5 344.23,-412.5 344.23,-412.5 466.33,-412.5 466.33,-412.5 472.33,-412.5 478.33,-418.5 478.33,-424.5 478.33,-424.5 478.33,-464.1 478.33,-464.1 478.33,-470.1 472.33,-476.1 466.33,-476.1\"/>\n",
       "<text text-anchor=\"start\" x=\"375.65\" y=\"-453\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">tokenizer</text>\n",
       "<text text-anchor=\"start\" x=\"343.03\" y=\"-425\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedTokenizer</text>\n",
       "</g>\n",
       "<!-- ner_pipeline -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>ner_pipeline</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M678.17,-394.1C678.17,-394.1 601.08,-394.1 601.08,-394.1 595.08,-394.1 589.08,-388.1 589.08,-382.1 589.08,-382.1 589.08,-342.5 589.08,-342.5 589.08,-336.5 595.08,-330.5 601.08,-330.5 601.08,-330.5 678.17,-330.5 678.17,-330.5 684.17,-330.5 690.17,-336.5 690.17,-342.5 690.17,-342.5 690.17,-382.1 690.17,-382.1 690.17,-388.1 684.17,-394.1 678.17,-394.1\"/>\n",
       "<text text-anchor=\"start\" x=\"599.88\" y=\"-371\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">ner_pipeline</text>\n",
       "<text text-anchor=\"start\" x=\"615.62\" y=\"-343\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Pipeline</text>\n",
       "</g>\n",
       "<!-- tokenizer&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>tokenizer&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M478.6,-422.01C496.82,-416.15 516.29,-409.68 534.2,-403.3 548.46,-398.22 563.68,-392.45 577.96,-386.87\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"579.24,-390.13 587.27,-383.22 576.68,-383.62 579.24,-390.13\"/>\n",
       "</g>\n",
       "<!-- retriever -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>retriever</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M704.05,-281.1C704.05,-281.1 575.2,-281.1 575.2,-281.1 569.2,-281.1 563.2,-275.1 563.2,-269.1 563.2,-269.1 563.2,-229.5 563.2,-229.5 563.2,-223.5 569.2,-217.5 575.2,-217.5 575.2,-217.5 704.05,-217.5 704.05,-217.5 710.05,-217.5 716.05,-223.5 716.05,-229.5 716.05,-229.5 716.05,-269.1 716.05,-269.1 716.05,-275.1 710.05,-281.1 704.05,-281.1\"/>\n",
       "<text text-anchor=\"start\" x=\"612.62\" y=\"-258\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">retriever</text>\n",
       "<text text-anchor=\"start\" x=\"574\" y=\"-230\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">SentenceTransformer</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-168.1C522.2,-168.1 288.35,-168.1 288.35,-168.1 282.35,-168.1 276.35,-162.1 276.35,-156.1 276.35,-156.1 276.35,-116.5 276.35,-116.5 276.35,-110.5 282.35,-104.5 288.35,-104.5 288.35,-104.5 522.2,-104.5 522.2,-104.5 528.2,-104.5 534.2,-110.5 534.2,-116.5 534.2,-116.5 534.2,-156.1 534.2,-156.1 534.2,-162.1 528.2,-168.1 522.2,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.03\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M692.42,-168.1C692.42,-168.1 586.83,-168.1 586.83,-168.1 580.83,-168.1 574.83,-162.1 574.83,-156.1 574.83,-156.1 574.83,-116.5 574.83,-116.5 574.83,-110.5 580.83,-104.5 586.83,-104.5 586.83,-104.5 692.42,-104.5 692.42,-104.5 698.42,-104.5 704.42,-110.5 704.42,-116.5 704.42,-116.5 704.42,-156.1 704.42,-156.1 704.42,-162.1 698.42,-168.1 692.42,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"585.62\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"616.38\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.68,-136.3C544.35,-136.3 553.9,-136.3 563.07,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"562.94,-139.8 572.94,-136.3 562.94,-132.8 562.94,-139.8\"/>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-168.1C235.35,-168.1 12,-168.1 12,-168.1 6,-168.1 0,-162.1 0,-156.1 0,-156.1 0,-116.5 0,-116.5 0,-110.5 6,-104.5 12,-104.5 12,-104.5 235.35,-104.5 235.35,-104.5 241.35,-104.5 247.35,-110.5 247.35,-116.5 247.35,-116.5 247.35,-156.1 247.35,-156.1 247.35,-162.1 241.35,-168.1 235.35,-168.1\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-145\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-117\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-136.3C253.29,-136.3 258.89,-136.3 264.49,-136.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-139.8 274.47,-136.3 264.47,-132.8 264.47,-139.8\"/>\n",
       "</g>\n",
       "<!-- NER_model_id -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>NER_model_id</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M170.47,-435.1C170.47,-435.1 76.87,-435.1 76.87,-435.1 70.87,-435.1 64.87,-429.1 64.87,-423.1 64.87,-423.1 64.87,-383.5 64.87,-383.5 64.87,-377.5 70.87,-371.5 76.87,-371.5 76.87,-371.5 170.47,-371.5 170.47,-371.5 176.47,-371.5 182.47,-377.5 182.47,-383.5 182.47,-383.5 182.47,-423.1 182.47,-423.1 182.47,-429.1 176.47,-435.1 170.47,-435.1\"/>\n",
       "<text text-anchor=\"start\" x=\"75.67\" y=\"-412\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">NER_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"116.17\" y=\"-384\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;tokenizer -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;tokenizer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.95,-411.85C222.77,-417.68 276.17,-425.51 320.53,-432.02\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"320,-435.48 330.4,-433.47 321.01,-428.55 320,-435.48\"/>\n",
       "</g>\n",
       "<!-- model -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M456.2,-394.1C456.2,-394.1 354.35,-394.1 354.35,-394.1 348.35,-394.1 342.35,-388.1 342.35,-382.1 342.35,-382.1 342.35,-342.5 342.35,-342.5 342.35,-336.5 348.35,-330.5 354.35,-330.5 354.35,-330.5 456.2,-330.5 456.2,-330.5 462.2,-330.5 468.2,-336.5 468.2,-342.5 468.2,-342.5 468.2,-382.1 468.2,-382.1 468.2,-388.1 462.2,-394.1 456.2,-394.1\"/>\n",
       "<text text-anchor=\"start\" x=\"385.03\" y=\"-371\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model</text>\n",
       "<text text-anchor=\"start\" x=\"353.15\" y=\"-343\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedModel</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;model -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M182.95,-394.75C225.97,-388.45 284.84,-379.81 331.04,-373.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"331.31,-376.54 340.7,-371.62 330.3,-369.61 331.31,-376.54\"/>\n",
       "</g>\n",
       "<!-- model&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>model&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M468.54,-362.3C502.21,-362.3 543.64,-362.3 577.18,-362.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"577.18,-365.8 587.18,-362.3 577.18,-358.8 577.18,-365.8\"/>\n",
       "</g>\n",
       "<!-- device -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>device</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M425.08,-312.1C425.08,-312.1 385.48,-312.1 385.48,-312.1 379.48,-312.1 373.48,-306.1 373.48,-300.1 373.48,-300.1 373.48,-260.5 373.48,-260.5 373.48,-254.5 379.48,-248.5 385.48,-248.5 385.48,-248.5 425.08,-248.5 425.08,-248.5 431.08,-248.5 437.08,-254.5 437.08,-260.5 437.08,-260.5 437.08,-300.1 437.08,-300.1 437.08,-306.1 431.08,-312.1 425.08,-312.1\"/>\n",
       "<text text-anchor=\"start\" x=\"384.28\" y=\"-289\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">device</text>\n",
       "<text text-anchor=\"start\" x=\"397.78\" y=\"-261\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- device&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>device&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M437.48,-289.69C463.5,-297.66 501.46,-309.64 534.2,-321.3 548.46,-326.38 563.68,-332.15 577.96,-337.73\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"576.68,-340.98 587.27,-341.38 579.24,-334.47 576.68,-340.98\"/>\n",
       "</g>\n",
       "<!-- device&#45;&gt;retriever -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>device&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M437.13,-276.18C466.35,-272.28 511.63,-266.24 551.87,-260.87\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"552.12,-264.37 561.57,-259.58 551.19,-257.43 552.12,-264.37\"/>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M704.65,-118.53C714.18,-115.89 724.08,-113.14 733.82,-110.43\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"734.58,-113.86 743.28,-107.81 732.7,-107.11 734.58,-113.86\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"713.92,-86.6 565.33,-86.6 565.33,0 713.92,0 713.92,-86.6\"/>\n",
       "<text text-anchor=\"start\" x=\"591.75\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"685\" y=\"-58.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"579.75\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"685\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"588\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"685\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M714.41,-63.34C720.86,-65.1 727.4,-66.87 733.87,-68.63\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"732.64,-71.92 743.21,-71.17 734.48,-65.17 732.64,-71.92\"/>\n",
       "</g>\n",
       "<!-- _retriever_inputs -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>_retriever_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"486.83,-230.6 323.73,-230.6 323.73,-186 486.83,-186 486.83,-230.6\"/>\n",
       "<text text-anchor=\"start\" x=\"338.53\" y=\"-202.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">retriever_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"457.03\" y=\"-202.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _retriever_inputs&#45;&gt;retriever -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>_retriever_inputs&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M487.21,-222.58C508.08,-226.27 530.63,-230.24 551.79,-233.98\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"551.1,-237.41 561.55,-235.7 552.31,-230.52 551.1,-237.41\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.67,-489.6 96.67,-489.6 96.67,-453 150.67,-453 150.67,-489.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-465.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-544.6C146.1,-544.6 101.25,-544.6 101.25,-544.6 95.25,-544.6 89.25,-538.6 89.25,-532.6 89.25,-532.6 89.25,-520 89.25,-520 89.25,-514 95.25,-508 101.25,-508 101.25,-508 146.1,-508 146.1,-508 152.1,-508 158.1,-514 158.1,-520 158.1,-520 158.1,-532.6 158.1,-532.6 158.1,-538.6 152.1,-544.6 146.1,-544.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-520.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x155e0bb80>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%incr_cell_to_module ner_module 3 --display\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "def retriever(\n",
    "    device: str, retriever_model_id: str = \"flax-sentence-embeddings/all_datasets_v3_mpnet-base\"\n",
    ") -> SentenceTransformer:\n",
    "    \"\"\"Our retriever model, used to create embeddings.\n",
    "\n",
    "    A retriever model embeds passages (article title + first 1000 characters)\n",
    "    and queries, such that queries and passages with similar meanings are close\n",
    "    in the vector space. We use a sentence-transformer model as our retriever.\n",
    "    \"\"\"\n",
    "    return SentenceTransformer(retriever_model_id, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0479f207",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.03609628, -0.03315403,  0.00881905,  0.04301339,  0.00257134,\n",
       "       -0.00996292,  0.02379813,  0.03957068, -0.03063051, -0.00725629],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# what the embedding model produces -- just show first 10 numbers\n",
    "ner_module.retriever(\"cpu\").encode([\"this is some text\"])[0][0:10]"
   ]
  },
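  {
   "cell_type": "markdown",
   "id": "3f7b2d90",
   "metadata": {},
   "source": [
    "The retriever maps text to vectors so that passages with similar meanings end up close together. A common way to measure that closeness is cosine similarity. The sketch below uses tiny hand-made 3-dimensional vectors as stand-ins for real embeddings, which have far more dimensions; the numbers are illustrative only, not model output.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import Sequence\n",
    "\n",
    "\n",
    "def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:\n",
    "    # cos(theta) = (a . b) / (|a| * |b|); 1.0 means same direction\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "\n",
    "# Toy vectors standing in for embeddings of a query and two passages\n",
    "query = [0.9, 0.1, 0.0]\n",
    "passage_vectors = {'similar passage': [0.8, 0.2, 0.1], 'unrelated passage': [0.0, 0.1, 0.9]}\n",
    "\n",
    "# Rank passages by similarity to the query, most similar first\n",
    "ranked = sorted(\n",
    "    passage_vectors,\n",
    "    key=lambda name: cosine_similarity(query, passage_vectors[name]),\n",
    "    reverse=True,\n",
    ")\n",
    "print(ranked)  # ['similar passage', 'unrelated passage']\n",
    "```\n",
    "\n",
    "With real embeddings you would compare the arrays returned by the retriever's `encode` method; `sentence_transformers.util.cos_sim` computes the same quantity. Vector search returns the nearest passages, which the NER-based filters can then narrow down further."
   ]
  },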
  {
   "cell_type": "markdown",
   "id": "158ca115",
   "metadata": {},
   "source": [
    "# Extracting entities & creating embeddings\n",
    "Next, let’s put this all together to extract entities and embed the documents.\n",
    "\n",
    "We do this with the `map` functionality of Hugging Face `datasets`. `map` can load the data in batches, which keeps data-hungry GPUs appropriately fed. What you pass to `map` is a function containing the logic to apply to each batch, so below we create some helper functions for that purpose. Keeping this logic in small helpers also makes it unit-testable and keeps the code clean. We then wire these helpers up to `map` calls that create the vector embedding and `named_entities` columns on the dataset.\n",
    "\n",
    "Finally, we prepare the data for loading into LanceDB with the `@save_to` data saver, which writes the dataset to LanceDB in batched chunks."
   ]
  },
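  {
   "cell_type": "markdown",
   "id": "9a4e6c12",
   "metadata": {},
   "source": [
    "As a minimal sketch of the pattern described above (not the exact functions used in this example), here are stand-alone helpers in the shape `Dataset.map(..., batched=True)` expects: each receives a batch as a dict of column name to list of values and returns new columns of the same length. The column names `title`, `text` and `vector`, the way title and text are joined, and all helper names are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "from typing import Callable, Dict, List\n",
    "\n",
    "# A batch as passed by Dataset.map(batched=True): column name -> list of values\n",
    "Batch = Dict[str, list]\n",
    "\n",
    "\n",
    "def passages(batch: Batch, max_text_length: int = 1000) -> List[str]:\n",
    "    # Assumed passage format: title, a space, then the first max_text_length characters\n",
    "    return [\n",
    "        title + ' ' + text[:max_text_length]\n",
    "        for title, text in zip(batch['title'], batch['text'])\n",
    "    ]\n",
    "\n",
    "\n",
    "def add_embeddings(batch: Batch, encode: Callable[[List[str]], list]) -> Batch:\n",
    "    # encode could be the retriever's encode method: one vector per passage\n",
    "    return {'vector': [list(vec) for vec in encode(passages(batch))]}\n",
    "\n",
    "\n",
    "def add_named_entities(batch: Batch, ner: Callable[[List[str]], list]) -> Batch:\n",
    "    # ner could be the NER pipeline: one list of entity dicts per passage\n",
    "    return {'named_entities': [[ent['word'] for ent in ents] for ents in ner(passages(batch))]}\n",
    "\n",
    "\n",
    "# Unit-testing the helpers with stubs instead of the real models\n",
    "def fake_encode(texts: List[str]) -> List[List[float]]:\n",
    "    return [[float(len(t)), 0.0] for t in texts]\n",
    "\n",
    "\n",
    "def fake_ner(texts: List[str]) -> List[List[dict]]:\n",
    "    return [[{'word': 'NASA', 'score': 0.99}] for _ in texts]\n",
    "\n",
    "\n",
    "batch = {'title': ['Mars'], 'text': ['The Mars Rover from NASA reached the red planet.']}\n",
    "print(add_named_entities(batch, fake_ner))  # {'named_entities': [['NASA']]}\n",
    "print(passages(batch, max_text_length=3))  # ['Mars The']\n",
    "```\n",
    "\n",
    "On a real `datasets.Dataset`, helpers like these could be applied with e.g. `ds.map(add_named_entities, batched=True, fn_kwargs={'ner': ner})`; the `batch_size` argument of `map` sets how many rows each call receives."
   ]
  },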
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0e72c866",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"1379pt\" height=\"691pt\"\n",
       " viewBox=\"0.00 0.00 1379.10 691.30\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 687.3)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-687.3 1375.1,-687.3 1375.1,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"69.62,-486.3 69.62,-675.3 177.72,-675.3 177.72,-486.3 69.62,-486.3\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-658\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- columns_of_interest -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>columns_of_interest</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1099.12,-290.1C1099.12,-290.1 968.78,-290.1 968.78,-290.1 962.78,-290.1 956.78,-284.1 956.78,-278.1 956.78,-278.1 956.78,-238.5 956.78,-238.5 956.78,-232.5 962.78,-226.5 968.78,-226.5 968.78,-226.5 1099.12,-226.5 1099.12,-226.5 1105.12,-226.5 1111.12,-232.5 1111.12,-238.5 1111.12,-238.5 1111.12,-278.1 1111.12,-278.1 1111.12,-284.1 1105.12,-290.1 1099.12,-290.1\"/>\n",
       "<text text-anchor=\"start\" x=\"967.58\" y=\"-267\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">columns_of_interest</text>\n",
       "<text text-anchor=\"start\" x=\"1025.7\" y=\"-239\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">list</text>\n",
       "</g>\n",
       "<!-- load_into_lancedb -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>load_into_lancedb</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1371.1,-212.07C1371.1,-216.46 1323.73,-220.03 1265.43,-220.03 1207.12,-220.03 1159.75,-216.46 1159.75,-212.07 1159.75,-212.07 1159.75,-140.53 1159.75,-140.53 1159.75,-136.14 1207.12,-132.58 1265.43,-132.58 1323.73,-132.58 1371.1,-136.14 1371.1,-140.53 1371.1,-140.53 1371.1,-212.07 1371.1,-212.07\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1371.1,-212.07C1371.1,-207.69 1323.73,-204.12 1265.43,-204.12 1207.12,-204.12 1159.75,-207.69 1159.75,-212.07\"/>\n",
       "<text text-anchor=\"start\" x=\"1206.18\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">load_into_lancedb</text>\n",
       "<text text-anchor=\"start\" x=\"1170.55\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">HuggingFaceDSLanceDBSaver</text>\n",
       "</g>\n",
       "<!-- columns_of_interest&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>columns_of_interest&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1111.36,-231C1123.38,-226.7 1136.03,-222.18 1148.7,-217.65\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1149.67,-221.02 1157.91,-214.36 1147.32,-214.43 1149.67,-221.02\"/>\n",
       "</g>\n",
       "<!-- final_dataset -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>final_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1074,-208.1C1074,-208.1 993.9,-208.1 993.9,-208.1 987.9,-208.1 981.9,-202.1 981.9,-196.1 981.9,-196.1 981.9,-156.5 981.9,-156.5 981.9,-150.5 987.9,-144.5 993.9,-144.5 993.9,-144.5 1074,-144.5 1074,-144.5 1080,-144.5 1086,-150.5 1086,-156.5 1086,-156.5 1086,-196.1 1086,-196.1 1086,-202.1 1080,-208.1 1074,-208.1\"/>\n",
       "<text text-anchor=\"start\" x=\"992.7\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">final_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"1010.7\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- final_dataset&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>final_dataset&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1086.45,-176.3C1104.9,-176.3 1126.53,-176.3 1148.23,-176.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1148.05,-179.8 1158.05,-176.3 1148.05,-172.8 1148.05,-179.8\"/>\n",
       "</g>\n",
       "<!-- tokenizer -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>tokenizer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M705.8,-290.1C705.8,-290.1 583.7,-290.1 583.7,-290.1 577.7,-290.1 571.7,-284.1 571.7,-278.1 571.7,-278.1 571.7,-238.5 571.7,-238.5 571.7,-232.5 577.7,-226.5 583.7,-226.5 583.7,-226.5 705.8,-226.5 705.8,-226.5 711.8,-226.5 717.8,-232.5 717.8,-238.5 717.8,-238.5 717.8,-278.1 717.8,-278.1 717.8,-284.1 711.8,-290.1 705.8,-290.1\"/>\n",
       "<text text-anchor=\"start\" x=\"615.13\" y=\"-267\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">tokenizer</text>\n",
       "<text text-anchor=\"start\" x=\"582.5\" y=\"-239\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedTokenizer</text>\n",
       "</g>\n",
       "<!-- ner_pipeline -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>ner_pipeline</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M870.28,-208.1C870.28,-208.1 793.18,-208.1 793.18,-208.1 787.18,-208.1 781.18,-202.1 781.18,-196.1 781.18,-196.1 781.18,-156.5 781.18,-156.5 781.18,-150.5 787.18,-144.5 793.18,-144.5 793.18,-144.5 870.28,-144.5 870.28,-144.5 876.28,-144.5 882.28,-150.5 882.28,-156.5 882.28,-156.5 882.28,-196.1 882.28,-196.1 882.28,-202.1 876.28,-208.1 870.28,-208.1\"/>\n",
       "<text text-anchor=\"start\" x=\"791.98\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">ner_pipeline</text>\n",
       "<text text-anchor=\"start\" x=\"807.73\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Pipeline</text>\n",
       "</g>\n",
       "<!-- tokenizer&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>tokenizer&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M718.06,-226.24C735.43,-218.54 753.8,-210.4 770.55,-202.98\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"771.64,-206.32 779.36,-199.07 768.8,-199.92 771.64,-206.32\"/>\n",
       "</g>\n",
       "<!-- model -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M695.68,-208.1C695.68,-208.1 593.83,-208.1 593.83,-208.1 587.83,-208.1 581.83,-202.1 581.83,-196.1 581.83,-196.1 581.83,-156.5 581.83,-156.5 581.83,-150.5 587.83,-144.5 593.83,-144.5 593.83,-144.5 695.68,-144.5 695.68,-144.5 701.68,-144.5 707.68,-150.5 707.68,-156.5 707.68,-156.5 707.68,-196.1 707.68,-196.1 707.68,-202.1 701.68,-208.1 695.68,-208.1\"/>\n",
       "<text text-anchor=\"start\" x=\"624.5\" y=\"-185\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model</text>\n",
       "<text text-anchor=\"start\" x=\"592.63\" y=\"-157\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedModel</text>\n",
       "</g>\n",
       "<!-- model&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>model&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M708.12,-176.3C727.93,-176.3 749.82,-176.3 769.55,-176.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"769.37,-179.8 779.37,-176.3 769.37,-172.8 769.37,-179.8\"/>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M885.65,-383.1C885.65,-383.1 777.8,-383.1 777.8,-383.1 771.8,-383.1 765.8,-377.1 765.8,-371.1 765.8,-371.1 765.8,-331.5 765.8,-331.5 765.8,-325.5 771.8,-319.5 777.8,-319.5 777.8,-319.5 885.65,-319.5 885.65,-319.5 891.65,-319.5 897.65,-325.5 897.65,-331.5 897.65,-331.5 897.65,-371.1 897.65,-371.1 897.65,-377.1 891.65,-383.1 885.65,-383.1\"/>\n",
       "<text text-anchor=\"start\" x=\"776.6\" y=\"-360\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"808.48\" y=\"-332\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- sampled_articles&#45;&gt;final_dataset -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>sampled_articles&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M851.22,-319.08C869.7,-289.28 900.59,-245.48 937.15,-217.3 947.35,-209.44 959.28,-202.78 971.19,-197.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"972.46,-200.53 980.22,-193.32 969.65,-194.12 972.46,-200.53\"/>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-476.1C235.35,-476.1 12,-476.1 12,-476.1 6,-476.1 0,-470.1 0,-464.1 0,-464.1 0,-424.5 0,-424.5 0,-418.5 6,-412.5 12,-412.5 12,-412.5 235.35,-412.5 235.35,-412.5 241.35,-412.5 247.35,-418.5 247.35,-424.5 247.35,-424.5 247.35,-464.1 247.35,-464.1 247.35,-470.1 241.35,-476.1 235.35,-476.1\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-453\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-425\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-476.1C522.2,-476.1 288.35,-476.1 288.35,-476.1 282.35,-476.1 276.35,-470.1 276.35,-464.1 276.35,-464.1 276.35,-424.5 276.35,-424.5 276.35,-418.5 282.35,-412.5 288.35,-412.5 288.35,-412.5 522.2,-412.5 522.2,-412.5 528.2,-412.5 534.2,-418.5 534.2,-424.5 534.2,-424.5 534.2,-464.1 534.2,-464.1 534.2,-470.1 528.2,-476.1 522.2,-476.1\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-453\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.03\" y=\"-425\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-444.3C253.29,-444.3 258.89,-444.3 264.49,-444.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-447.8 274.47,-444.3 264.47,-440.8 264.47,-447.8\"/>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M697.55,-476.1C697.55,-476.1 591.95,-476.1 591.95,-476.1 585.95,-476.1 579.95,-470.1 579.95,-464.1 579.95,-464.1 579.95,-424.5 579.95,-424.5 579.95,-418.5 585.95,-412.5 591.95,-412.5 591.95,-412.5 697.55,-412.5 697.55,-412.5 703.55,-412.5 709.55,-418.5 709.55,-424.5 709.55,-424.5 709.55,-464.1 709.55,-464.1 709.55,-470.1 703.55,-476.1 697.55,-476.1\"/>\n",
       "<text text-anchor=\"start\" x=\"590.75\" y=\"-453\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"621.5\" y=\"-425\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.4,-444.3C546.02,-444.3 557.53,-444.3 568.46,-444.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"568.12,-447.8 578.12,-444.3 568.12,-440.8 568.12,-447.8\"/>\n",
       "</g>\n",
       "<!-- ner_pipeline&#45;&gt;final_dataset -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>ner_pipeline&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M882.73,-176.3C909.2,-176.3 941.97,-176.3 970.14,-176.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"969.98,-179.8 979.98,-176.3 969.98,-172.8 969.98,-179.8\"/>\n",
       "</g>\n",
       "<!-- retriever -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>retriever</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M896.15,-126.1C896.15,-126.1 767.3,-126.1 767.3,-126.1 761.3,-126.1 755.3,-120.1 755.3,-114.1 755.3,-114.1 755.3,-74.5 755.3,-74.5 755.3,-68.5 761.3,-62.5 767.3,-62.5 767.3,-62.5 896.15,-62.5 896.15,-62.5 902.15,-62.5 908.15,-68.5 908.15,-74.5 908.15,-74.5 908.15,-114.1 908.15,-114.1 908.15,-120.1 902.15,-126.1 896.15,-126.1\"/>\n",
       "<text text-anchor=\"start\" x=\"804.73\" y=\"-103\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">retriever</text>\n",
       "<text text-anchor=\"start\" x=\"766.1\" y=\"-75\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">SentenceTransformer</text>\n",
       "</g>\n",
       "<!-- retriever&#45;&gt;final_dataset -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>retriever&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M908.42,-125.3C929.07,-133.76 951.2,-142.82 970.99,-150.93\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"969.43,-154.07 980.01,-154.62 972.08,-147.59 969.43,-154.07\"/>\n",
       "</g>\n",
       "<!-- NER_model_id -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>NER_model_id</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M452.08,-249.1C452.08,-249.1 358.48,-249.1 358.48,-249.1 352.48,-249.1 346.48,-243.1 346.48,-237.1 346.48,-237.1 346.48,-197.5 346.48,-197.5 346.48,-191.5 352.48,-185.5 358.48,-185.5 358.48,-185.5 452.08,-185.5 452.08,-185.5 458.08,-185.5 464.08,-191.5 464.08,-197.5 464.08,-197.5 464.08,-237.1 464.08,-237.1 464.08,-243.1 458.08,-249.1 452.08,-249.1\"/>\n",
       "<text text-anchor=\"start\" x=\"357.28\" y=\"-226\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">NER_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"397.78\" y=\"-198\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;tokenizer -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;tokenizer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-227.33C493.23,-232.31 528.63,-238.43 560.34,-243.9\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"559.44,-247.3 569.89,-245.55 560.63,-240.4 559.44,-247.3\"/>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;model -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-207.27C496.32,-201.75 536.28,-194.85 570.37,-188.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"570.94,-192.42 580.2,-187.27 569.75,-185.53 570.94,-192.42\"/>\n",
       "</g>\n",
       "<!-- device -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>device</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M664.55,-126.1C664.55,-126.1 624.95,-126.1 624.95,-126.1 618.95,-126.1 612.95,-120.1 612.95,-114.1 612.95,-114.1 612.95,-74.5 612.95,-74.5 612.95,-68.5 618.95,-62.5 624.95,-62.5 624.95,-62.5 664.55,-62.5 664.55,-62.5 670.55,-62.5 676.55,-68.5 676.55,-74.5 676.55,-74.5 676.55,-114.1 676.55,-114.1 676.55,-120.1 670.55,-126.1 664.55,-126.1\"/>\n",
       "<text text-anchor=\"start\" x=\"623.75\" y=\"-103\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">device</text>\n",
       "<text text-anchor=\"start\" x=\"637.25\" y=\"-75\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- device&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>device&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.71,-108.02C702.21,-119.33 739.14,-135.7 770.44,-149.58\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"768.87,-152.71 779.43,-153.56 771.71,-146.31 768.87,-152.71\"/>\n",
       "</g>\n",
       "<!-- device&#45;&gt;retriever -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>device&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.71,-94.3C695.18,-94.3 719.64,-94.3 743.61,-94.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"743.42,-97.8 753.42,-94.3 743.42,-90.8 743.42,-97.8\"/>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M709.93,-412.03C724.99,-404.46 741.17,-396.33 756.62,-388.56\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"757.95,-391.81 765.31,-384.19 754.8,-385.55 757.95,-391.81\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"719.05,-394.6 570.45,-394.6 570.45,-308 719.05,-308 719.05,-394.6\"/>\n",
       "<text text-anchor=\"start\" x=\"596.88\" y=\"-366.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-366.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"584.88\" y=\"-345.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-345.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"593.13\" y=\"-324.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-324.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M719.39,-351.3C730.8,-351.3 742.61,-351.3 754.08,-351.3\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"753.9,-354.8 763.9,-351.3 753.9,-347.8 753.9,-354.8\"/>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>_load_into_lancedb_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1130.75,-126.1 937.15,-126.1 937.15,-60.5 1130.75,-60.5 1130.75,-126.1\"/>\n",
       "<text text-anchor=\"start\" x=\"951.83\" y=\"-98\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">table_name</text>\n",
       "<text text-anchor=\"start\" x=\"1064.95\" y=\"-98\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "<text text-anchor=\"start\" x=\"960.45\" y=\"-77\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">db_client</text>\n",
       "<text text-anchor=\"start\" x=\"1028.95\" y=\"-77\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">DBConnection</text>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>_load_into_lancedb_inputs&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1126.96,-126.58C1134.17,-129.19 1141.5,-131.84 1148.83,-134.49\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1147.55,-137.75 1158.14,-137.86 1149.93,-131.17 1147.55,-137.75\"/>\n",
       "</g>\n",
       "<!-- _retriever_inputs -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>_retriever_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"726.3,-44.6 563.2,-44.6 563.2,0 726.3,0 726.3,-44.6\"/>\n",
       "<text text-anchor=\"start\" x=\"578\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">retriever_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"696.5\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _retriever_inputs&#45;&gt;retriever -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>_retriever_inputs&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M704.72,-44.94C711.99,-47.75 719.32,-50.59 726.3,-53.3 732.23,-55.61 738.35,-57.99 744.51,-60.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"742.89,-63.52 753.48,-63.91 745.44,-57 742.89,-63.52\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.67,-530.6 96.67,-530.6 96.67,-494 150.67,-494 150.67,-530.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-506.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-585.6C146.1,-585.6 101.25,-585.6 101.25,-585.6 95.25,-585.6 89.25,-579.6 89.25,-573.6 89.25,-573.6 89.25,-561 89.25,-561 89.25,-555 95.25,-549 101.25,-549 101.25,-549 146.1,-549 146.1,-549 152.1,-549 158.1,-555 158.1,-561 158.1,-561 158.1,-573.6 158.1,-573.6 158.1,-579.6 152.1,-585.6 146.1,-585.6\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-561.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- materializer -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>materializer</title>\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M169.72,-640.84C169.72,-642.87 149.08,-644.51 123.67,-644.51 98.27,-644.51 77.62,-642.87 77.62,-640.84 77.62,-640.84 77.62,-607.76 77.62,-607.76 77.62,-605.73 98.27,-604.09 123.67,-604.09 149.08,-604.09 169.72,-605.73 169.72,-607.76 169.72,-607.76 169.72,-640.84 169.72,-640.84\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M169.72,-640.84C169.72,-638.81 149.08,-637.16 123.67,-637.16 98.27,-637.16 77.62,-638.81 77.62,-640.84\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-618.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">materializer</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x155e32770>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%incr_cell_to_module ner_module 4 --display\n",
    "from datasets.formatting.formatting import LazyBatch\n",
    "from typing import Union\n",
    "\n",
    "def _extract_named_entities_text(\n",
    "    title_text_batch: Union[LazyBatch, list[str]], _ner_pipeline\n",
    ") -> list[list[str]]:\n",
    "    \"\"\"Helper function to extract named entities given a batch of text.\"\"\"\n",
    "    # extract named entities using the NER pipeline\n",
    "    extracted_batch = _ner_pipeline(title_text_batch)\n",
    "    # batching itself is driven by Dataset.map() in final_dataset, which calls this helper once per batch\n",
    "    entities = []\n",
    "    # loop through the results and only select the entity names\n",
    "    for text in extracted_batch:\n",
    "        ne = [entity[\"word\"] for entity in text]\n",
    "        entities.append(ne)\n",
    "    # de-duplicate the entity names found in each text\n",
    "    _named_entities = [list(set(entity)) for entity in entities]\n",
    "    return _named_entities\n",
    "\n",
    "\n",
    "def _batch_map(dataset: LazyBatch, _retriever, _ner_pipeline) -> dict:\n",
    "    \"\"\"Helper function to create the embedding vectors and extract named entities.\"\"\"\n",
    "    title_text_list = dataset[\"title_text\"]\n",
    "    emb = _retriever.encode(title_text_list)\n",
    "    _named_entities = _extract_named_entities_text(title_text_list, _ner_pipeline)\n",
    "    return {\n",
    "        \"vector\": emb,\n",
    "        \"named_entities\": _named_entities,\n",
    "    }\n",
    "\n",
    "\n",
    "def columns_of_interest() -> list[str]:\n",
    "    \"\"\"The columns we expect to pull from the dataset to be saved to lancedb\"\"\"\n",
    "    return [\"vector\", \"named_entities\", \"title\", \"url\", \"authors\", \"timestamp\", \"tags\"]\n",
    "\n",
    "\n",
    "@save_to.lancedb(\n",
    "    db_client=source(\"db_client\"),\n",
    "    table_name=source(\"table_name\"),\n",
    "    columns_to_write=source(\"columns_of_interest\"),\n",
    "    output_name_=\"load_into_lancedb\",\n",
    ")\n",
    "def final_dataset(\n",
    "    sampled_articles: Dataset,\n",
    "    retriever: SentenceTransformer,\n",
    "    ner_pipeline: base.Pipeline,\n",
    ") -> Dataset:\n",
    "    \"\"\"The final dataset to be pushed to lancedb.\n",
    "\n",
    "    This adds two columns:\n",
    "\n",
    "     - vector -- the vector embedding\n",
    "     - named_entities -- the names of entities extracted from the text\n",
    "    \"\"\"\n",
    "    # goes over the data in batches so that the GPU can be properly utilized.\n",
    "    final_ds = sampled_articles.map(\n",
    "        _batch_map,\n",
    "        batched=True,\n",
    "        fn_kwargs={\"_retriever\": retriever, \"_ner_pipeline\": ner_pipeline},\n",
    "        desc=\"extracting entities\",\n",
    "    )\n",
    "    return final_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9663faa6",
   "metadata": {},
   "source": [
    "# Load data into LanceDB\n",
    "\n",
    "With our processing pipeline now ready, let's load some data into LanceDB.\n",
    "\n",
    "We'll do this by instantiating a Hamilton driver to execute our pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "91890b7e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"1379pt\" height=\"713pt\"\n",
       " viewBox=\"0.00 0.00 1379.10 712.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 708.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-708.8 1375.1,-708.8 1375.1,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"69.62,-507.8 69.62,-696.8 177.72,-696.8 177.72,-507.8 69.62,-507.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-679.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-497.6C235.35,-497.6 12,-497.6 12,-497.6 6,-497.6 0,-491.6 0,-485.6 0,-485.6 0,-446 0,-446 0,-440 6,-434 12,-434 12,-434 235.35,-434 235.35,-434 241.35,-434 247.35,-440 247.35,-446 247.35,-446 247.35,-485.6 247.35,-485.6 247.35,-491.6 241.35,-497.6 235.35,-497.6\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-474.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-446.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-497.6C522.2,-497.6 288.35,-497.6 288.35,-497.6 282.35,-497.6 276.35,-491.6 276.35,-485.6 276.35,-485.6 276.35,-446 276.35,-446 276.35,-440 282.35,-434 288.35,-434 288.35,-434 522.2,-434 522.2,-434 528.2,-434 534.2,-440 534.2,-446 534.2,-446 534.2,-485.6 534.2,-485.6 534.2,-491.6 528.2,-497.6 522.2,-497.6\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-474.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.03\" y=\"-446.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-465.8C253.29,-465.8 258.89,-465.8 264.49,-465.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-469.3 274.47,-465.8 264.47,-462.3 264.47,-469.3\"/>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M697.55,-497.6C697.55,-497.6 591.95,-497.6 591.95,-497.6 585.95,-497.6 579.95,-491.6 579.95,-485.6 579.95,-485.6 579.95,-446 579.95,-446 579.95,-440 585.95,-434 591.95,-434 591.95,-434 697.55,-434 697.55,-434 703.55,-434 709.55,-440 709.55,-446 709.55,-446 709.55,-485.6 709.55,-485.6 709.55,-491.6 703.55,-497.6 697.55,-497.6\"/>\n",
       "<text text-anchor=\"start\" x=\"590.75\" y=\"-474.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"621.5\" y=\"-446.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.4,-465.8C546.02,-465.8 557.53,-465.8 568.46,-465.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"568.12,-469.3 578.12,-465.8 568.12,-462.3 568.12,-469.3\"/>\n",
       "</g>\n",
       "<!-- retriever -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>retriever</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M896.15,-147.6C896.15,-147.6 767.3,-147.6 767.3,-147.6 761.3,-147.6 755.3,-141.6 755.3,-135.6 755.3,-135.6 755.3,-96 755.3,-96 755.3,-90 761.3,-84 767.3,-84 767.3,-84 896.15,-84 896.15,-84 902.15,-84 908.15,-90 908.15,-96 908.15,-96 908.15,-135.6 908.15,-135.6 908.15,-141.6 902.15,-147.6 896.15,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"804.73\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">retriever</text>\n",
       "<text text-anchor=\"start\" x=\"766.1\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">SentenceTransformer</text>\n",
       "</g>\n",
       "<!-- final_dataset -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>final_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1074,-229.6C1074,-229.6 993.9,-229.6 993.9,-229.6 987.9,-229.6 981.9,-223.6 981.9,-217.6 981.9,-217.6 981.9,-178 981.9,-178 981.9,-172 987.9,-166 993.9,-166 993.9,-166 1074,-166 1074,-166 1080,-166 1086,-172 1086,-178 1086,-178 1086,-217.6 1086,-217.6 1086,-223.6 1080,-229.6 1074,-229.6\"/>\n",
       "<text text-anchor=\"start\" x=\"992.7\" y=\"-206.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">final_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"1010.7\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- retriever&#45;&gt;final_dataset -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>retriever&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M908.42,-146.8C929.07,-155.26 951.2,-164.32 970.99,-172.43\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"969.43,-175.57 980.01,-176.12 972.08,-169.09 969.43,-175.57\"/>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M885.65,-404.6C885.65,-404.6 777.8,-404.6 777.8,-404.6 771.8,-404.6 765.8,-398.6 765.8,-392.6 765.8,-392.6 765.8,-353 765.8,-353 765.8,-347 771.8,-341 777.8,-341 777.8,-341 885.65,-341 885.65,-341 891.65,-341 897.65,-347 897.65,-353 897.65,-353 897.65,-392.6 897.65,-392.6 897.65,-398.6 891.65,-404.6 885.65,-404.6\"/>\n",
       "<text text-anchor=\"start\" x=\"776.6\" y=\"-381.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"808.48\" y=\"-353.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- sampled_articles&#45;&gt;final_dataset -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>sampled_articles&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M869.45,-340.7C902.79,-311.56 952.07,-268.49 987.8,-237.26\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"989.75,-240.2 994.98,-230.99 985.15,-234.93 989.75,-240.2\"/>\n",
       "</g>\n",
       "<!-- columns_of_interest -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>columns_of_interest</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1099.12,-147.6C1099.12,-147.6 968.78,-147.6 968.78,-147.6 962.78,-147.6 956.78,-141.6 956.78,-135.6 956.78,-135.6 956.78,-96 956.78,-96 956.78,-90 962.78,-84 968.78,-84 968.78,-84 1099.12,-84 1099.12,-84 1105.12,-84 1111.12,-90 1111.12,-96 1111.12,-96 1111.12,-135.6 1111.12,-135.6 1111.12,-141.6 1105.12,-147.6 1099.12,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"967.58\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">columns_of_interest</text>\n",
       "<text text-anchor=\"start\" x=\"1025.7\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">list</text>\n",
       "</g>\n",
       "<!-- load_into_lancedb -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>load_into_lancedb</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1371.1,-151.57C1371.1,-155.96 1323.73,-159.53 1265.43,-159.53 1207.12,-159.53 1159.75,-155.96 1159.75,-151.57 1159.75,-151.57 1159.75,-80.03 1159.75,-80.03 1159.75,-75.64 1207.12,-72.08 1265.43,-72.08 1323.73,-72.08 1371.1,-75.64 1371.1,-80.03 1371.1,-80.03 1371.1,-151.57 1371.1,-151.57\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1371.1,-151.57C1371.1,-147.19 1323.73,-143.62 1265.43,-143.62 1207.12,-143.62 1159.75,-147.19 1159.75,-151.57\"/>\n",
       "<text text-anchor=\"start\" x=\"1206.18\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">load_into_lancedb</text>\n",
       "<text text-anchor=\"start\" x=\"1170.55\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">HuggingFaceDSLanceDBSaver</text>\n",
       "</g>\n",
       "<!-- columns_of_interest&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>columns_of_interest&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1111.36,-115.8C1123.14,-115.8 1135.53,-115.8 1147.94,-115.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1147.82,-119.3 1157.82,-115.8 1147.82,-112.3 1147.82,-119.3\"/>\n",
       "</g>\n",
       "<!-- model -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M695.68,-311.6C695.68,-311.6 593.83,-311.6 593.83,-311.6 587.83,-311.6 581.83,-305.6 581.83,-299.6 581.83,-299.6 581.83,-260 581.83,-260 581.83,-254 587.83,-248 593.83,-248 593.83,-248 695.68,-248 695.68,-248 701.68,-248 707.68,-254 707.68,-260 707.68,-260 707.68,-299.6 707.68,-299.6 707.68,-305.6 701.68,-311.6 695.68,-311.6\"/>\n",
       "<text text-anchor=\"start\" x=\"624.5\" y=\"-288.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model</text>\n",
       "<text text-anchor=\"start\" x=\"592.63\" y=\"-260.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedModel</text>\n",
       "</g>\n",
       "<!-- ner_pipeline -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>ner_pipeline</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M870.28,-229.6C870.28,-229.6 793.18,-229.6 793.18,-229.6 787.18,-229.6 781.18,-223.6 781.18,-217.6 781.18,-217.6 781.18,-178 781.18,-178 781.18,-172 787.18,-166 793.18,-166 793.18,-166 870.28,-166 870.28,-166 876.28,-166 882.28,-172 882.28,-178 882.28,-178 882.28,-217.6 882.28,-217.6 882.28,-223.6 876.28,-229.6 870.28,-229.6\"/>\n",
       "<text text-anchor=\"start\" x=\"791.98\" y=\"-206.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">ner_pipeline</text>\n",
       "<text text-anchor=\"start\" x=\"807.73\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Pipeline</text>\n",
       "</g>\n",
       "<!-- model&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>model&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M708.12,-252.15C728.21,-243.25 750.43,-233.39 770.37,-224.55\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"771.78,-227.76 779.5,-220.51 768.94,-221.36 771.78,-227.76\"/>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M709.93,-433.53C724.99,-425.96 741.17,-417.83 756.62,-410.06\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"757.95,-413.31 765.31,-405.69 754.8,-407.05 757.95,-413.31\"/>\n",
       "</g>\n",
       "<!-- final_dataset&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>final_dataset&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1086.45,-179.4C1105.07,-172.74 1126.93,-164.93 1148.83,-157.11\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1149.9,-160.44 1158.14,-153.78 1147.55,-153.85 1149.9,-160.44\"/>\n",
       "</g>\n",
       "<!-- tokenizer -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>tokenizer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M705.8,-229.6C705.8,-229.6 583.7,-229.6 583.7,-229.6 577.7,-229.6 571.7,-223.6 571.7,-217.6 571.7,-217.6 571.7,-178 571.7,-178 571.7,-172 577.7,-166 583.7,-166 583.7,-166 705.8,-166 705.8,-166 711.8,-166 717.8,-172 717.8,-178 717.8,-178 717.8,-217.6 717.8,-217.6 717.8,-223.6 711.8,-229.6 705.8,-229.6\"/>\n",
       "<text text-anchor=\"start\" x=\"615.13\" y=\"-206.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">tokenizer</text>\n",
       "<text text-anchor=\"start\" x=\"582.5\" y=\"-178.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedTokenizer</text>\n",
       "</g>\n",
       "<!-- tokenizer&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>tokenizer&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M718.06,-197.8C735.1,-197.8 753.11,-197.8 769.6,-197.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"769.23,-201.3 779.23,-197.8 769.23,-194.3 769.23,-201.3\"/>\n",
       "</g>\n",
       "<!-- NER_model_id -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>NER_model_id</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M452.08,-270.6C452.08,-270.6 358.48,-270.6 358.48,-270.6 352.48,-270.6 346.48,-264.6 346.48,-258.6 346.48,-258.6 346.48,-219 346.48,-219 346.48,-213 352.48,-207 358.48,-207 358.48,-207 452.08,-207 452.08,-207 458.08,-207 464.08,-213 464.08,-219 464.08,-219 464.08,-258.6 464.08,-258.6 464.08,-264.6 458.08,-270.6 452.08,-270.6\"/>\n",
       "<text text-anchor=\"start\" x=\"357.28\" y=\"-247.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">NER_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"397.78\" y=\"-219.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;model -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-248.83C496.32,-254.35 536.28,-261.25 570.37,-267.13\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"569.75,-270.57 580.2,-268.83 570.94,-263.68 569.75,-270.57\"/>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;tokenizer -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;tokenizer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-228.77C493.23,-223.79 528.63,-217.67 560.34,-212.2\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"560.63,-215.7 569.89,-210.55 559.44,-208.8 560.63,-215.7\"/>\n",
       "</g>\n",
       "<!-- device -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>device</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M664.55,-147.6C664.55,-147.6 624.95,-147.6 624.95,-147.6 618.95,-147.6 612.95,-141.6 612.95,-135.6 612.95,-135.6 612.95,-96 612.95,-96 612.95,-90 618.95,-84 624.95,-84 624.95,-84 664.55,-84 664.55,-84 670.55,-84 676.55,-90 676.55,-96 676.55,-96 676.55,-135.6 676.55,-135.6 676.55,-141.6 670.55,-147.6 664.55,-147.6\"/>\n",
       "<text text-anchor=\"start\" x=\"623.75\" y=\"-124.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">device</text>\n",
       "<text text-anchor=\"start\" x=\"637.25\" y=\"-96.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- device&#45;&gt;retriever -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>device&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.71,-115.8C695.18,-115.8 719.64,-115.8 743.61,-115.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"743.42,-119.3 753.42,-115.8 743.42,-112.3 743.42,-119.3\"/>\n",
       "</g>\n",
       "<!-- device&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>device&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.71,-129.52C702.21,-140.83 739.14,-157.2 770.44,-171.08\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"768.87,-174.21 779.43,-175.06 771.71,-167.81 768.87,-174.21\"/>\n",
       "</g>\n",
       "<!-- ner_pipeline&#45;&gt;final_dataset -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>ner_pipeline&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M882.73,-197.8C909.2,-197.8 941.97,-197.8 970.14,-197.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"969.98,-201.3 979.98,-197.8 969.98,-194.3 969.98,-201.3\"/>\n",
       "</g>\n",
       "<!-- _retriever_inputs -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>_retriever_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"726.3,-66.1 563.2,-66.1 563.2,-21.5 726.3,-21.5 726.3,-66.1\"/>\n",
       "<text text-anchor=\"start\" x=\"578\" y=\"-38\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">retriever_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"696.5\" y=\"-38\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _retriever_inputs&#45;&gt;retriever -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>_retriever_inputs&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M704.72,-66.44C711.99,-69.25 719.32,-72.09 726.3,-74.8 732.23,-77.11 738.35,-79.49 744.51,-81.9\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"742.89,-85.02 753.48,-85.41 745.44,-78.5 742.89,-85.02\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"719.05,-416.1 570.45,-416.1 570.45,-329.5 719.05,-329.5 719.05,-416.1\"/>\n",
       "<text text-anchor=\"start\" x=\"596.88\" y=\"-388\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-388\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"593.13\" y=\"-367\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-367\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"584.88\" y=\"-346\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"690.13\" y=\"-346\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M719.39,-372.8C730.8,-372.8 742.61,-372.8 754.08,-372.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"753.9,-376.3 763.9,-372.8 753.9,-369.3 753.9,-376.3\"/>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>_load_into_lancedb_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1130.75,-65.6 937.15,-65.6 937.15,0 1130.75,0 1130.75,-65.6\"/>\n",
       "<text text-anchor=\"start\" x=\"951.83\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">table_name</text>\n",
       "<text text-anchor=\"start\" x=\"1064.95\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "<text text-anchor=\"start\" x=\"960.45\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">db_client</text>\n",
       "<text text-anchor=\"start\" x=\"1028.95\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">DBConnection</text>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>_load_into_lancedb_inputs&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1126.96,-66.08C1134.17,-68.69 1141.5,-71.34 1148.83,-73.99\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1147.55,-77.25 1158.14,-77.36 1149.93,-70.67 1147.55,-77.25\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.67,-552.1 96.67,-552.1 96.67,-515.5 150.67,-515.5 150.67,-552.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-528\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-607.1C146.1,-607.1 101.25,-607.1 101.25,-607.1 95.25,-607.1 89.25,-601.1 89.25,-595.1 89.25,-595.1 89.25,-582.5 89.25,-582.5 89.25,-576.5 95.25,-570.5 101.25,-570.5 101.25,-570.5 146.1,-570.5 146.1,-570.5 152.1,-570.5 158.1,-576.5 158.1,-582.5 158.1,-582.5 158.1,-595.1 158.1,-595.1 158.1,-601.1 152.1,-607.1 146.1,-607.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-583\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- materializer -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>materializer</title>\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M169.72,-662.34C169.72,-664.37 149.08,-666.01 123.67,-666.01 98.27,-666.01 77.62,-664.37 77.62,-662.34 77.62,-662.34 77.62,-629.26 77.62,-629.26 77.62,-627.23 98.27,-625.59 123.67,-625.59 149.08,-625.59 169.72,-627.23 169.72,-629.26 169.72,-629.26 169.72,-662.34 169.72,-662.34\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M169.72,-662.34C169.72,-660.31 149.08,-658.66 123.67,-658.66 98.27,-658.66 77.62,-660.31 77.62,-662.34\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-640\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">materializer</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<hamilton.driver.Driver at 0x1529067a0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from hamilton import driver, lifecycle\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_config({})\n",
    "    .with_modules(ner_module)\n",
    "    .with_adapters(lifecycle.PrintLn())\n",
    "    .build()\n",
    ")\n",
    "dr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cbeb3ff7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing node: columns_of_interest.\n",
      "Finished debugging node: columns_of_interest in 609μs. Status: Success.\n",
      "Executing node: medium_articles.load_data.dataset.\n",
      "Finished debugging node: medium_articles.load_data.dataset in 1.87s. Status: Success.\n",
      "Executing node: medium_articles.select_data.dataset.\n",
      "Finished debugging node: medium_articles.select_data.dataset in 17.9μs. Status: Success.\n",
      "Executing node: medium_articles.\n",
      "Finished debugging node: medium_articles in 23.1μs. Status: Success.\n",
      "Executing node: sampled_articles.\n",
      "Finished debugging node: sampled_articles in 30ms. Status: Success.\n",
      "Executing node: device.\n",
      "Finished debugging node: device in 40.1μs. Status: Success.\n",
      "Executing node: retriever.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stefankrawczyk/.pyenv/versions/3.10.4/envs/ner-example-py310/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: retriever in 1.11s. Status: Success.\n",
      "Executing node: NER_model_id.\n",
      "Finished debugging node: NER_model_id in 17.9μs. Status: Success.\n",
      "Executing node: model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: model in 340ms. Status: Success.\n",
      "Executing node: tokenizer.\n",
      "Finished debugging node: tokenizer in 158ms. Status: Success.\n",
      "Executing node: ner_pipeline.\n",
      "Finished debugging node: ner_pipeline in 2.52ms. Status: Success.\n",
      "Executing node: final_dataset.\n",
      "Finished debugging node: final_dataset in 5.96s. Status: Success.\n",
      "Executing node: load_into_lancedb.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f7812bc80514104978704f23e3a50a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "writing to lancedb table medium_docs:   0%|          | 0/104 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: load_into_lancedb in 215ms. Status: Success.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'load_into_lancedb': {'db_meta': {'table_name': 'medium_docs'},\n",
       "  'dataset_metadata': {'rows': 104,\n",
       "   'columns': ['title',\n",
       "    'text',\n",
       "    'url',\n",
       "    'authors',\n",
       "    'timestamp',\n",
       "    'tags',\n",
       "    'title_text',\n",
       "    'vector',\n",
       "    'named_entities'],\n",
       "   'size_in_bytes': 2087095068,\n",
       "   'features': {'title': {'dtype': 'string', '_type': 'Value'},\n",
       "    'text': {'dtype': 'string', '_type': 'Value'},\n",
       "    'url': {'dtype': 'string', '_type': 'Value'},\n",
       "    'authors': {'dtype': 'string', '_type': 'Value'},\n",
       "    'timestamp': {'dtype': 'string', '_type': 'Value'},\n",
       "    'tags': {'dtype': 'string', '_type': 'Value'},\n",
       "    'title_text': {'dtype': 'string', '_type': 'Value'},\n",
       "    'vector': {'feature': {'dtype': 'float32', '_type': 'Value'},\n",
       "     '_type': 'Sequence'},\n",
       "    'named_entities': {'feature': {'dtype': 'string', '_type': 'Value'},\n",
       "     '_type': 'Sequence'}}}}}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# now we execute it - we specify the output that we want.\n",
    "import lancedb\n",
    "table_name = \"medium_docs\"\n",
    "db_client =  lancedb.connect(\"./.lancedb\")\n",
    "\n",
    "results = dr.execute(\n",
    "    [\"load_into_lancedb\"],\n",
    "    inputs={\"table_name\": table_name, \"db_client\": db_client},\n",
    ")\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d35b4c6",
   "metadata": {},
   "source": [
    "# Querying for results\n",
    "\n",
    "We can now query the DB. To do that let's create an inference portion of our pipeline.\n",
    "\n",
    "We'll write a function to extract entities from the query, and then construct the appropriate lancedb query to\n",
    "filter results only if there are named entities in common."
   ]
  },
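  {
   "cell_type": "markdown",
   "id": "7e2b9c41",
   "metadata": {},
   "source": [
    "As a rough sketch of the filtering step described above (not necessarily how the notebook's own functions below implement it), the hypothetical helper `build_entity_filter` turns the entities extracted from a query into a LanceDB `where` clause. It assumes the `named_entities` list column written by `load_into_lancedb`. It also assumes LanceDB's SQL filters accept the DataFusion array functions `array_has_any` and `make_array`. Verify both assumptions against your LanceDB version.\n",
    "\n",
    "```python\n",
    "from typing import List, Optional\n",
    "\n",
    "\n",
    "def build_entity_filter(entities: List[str], column: str = \"named_entities\") -> Optional[str]:\n",
    "    \"\"\"Return a filter matching rows that share at least one entity (hypothetical helper).\"\"\"\n",
    "    # Drop empty strings and duplicates while keeping the original order.\n",
    "    unique = [e for e in dict.fromkeys(entities) if e]\n",
    "    if not unique:\n",
    "        # No entities in the query: the caller should run an unfiltered vector search.\n",
    "        return None\n",
    "    # Escape single quotes by doubling them, as SQL string literals require.\n",
    "    literals = \", \".join(\"'\" + e.replace(\"'\", \"''\") + \"'\" for e in unique)\n",
    "    return f\"array_has_any({column}, make_array({literals}))\"\n",
    "\n",
    "\n",
    "# Example usage (assumes a LanceDB `table`, a SentenceTransformer `retriever` and a `query` string):\n",
    "# where = build_entity_filter([\"Hamilton\", \"LanceDB\"])\n",
    "# search = table.search(retriever.encode(query).tolist()).limit(10)\n",
    "# results = (search.where(where, prefilter=True) if where else search).to_list()\n",
    "```\n",
    "\n",
    "Returning `None` when the query contains no entities lets the caller fall back to a plain vector search instead of filtering out every row. With `prefilter=True`, LanceDB applies the filter before the nearest-neighbour search, so every returned row satisfies it."
   ]
  },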
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5575f45b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"1420pt\" height=\"779pt\"\n",
       " viewBox=\"0.00 0.00 1419.85 778.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 774.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-774.8 1415.85,-774.8 1415.85,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"69.62,-573.8 69.62,-762.8 177.72,-762.8 177.72,-573.8 69.62,-573.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-745.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- columns_of_interest -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>columns_of_interest</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1139.87,-452.6C1139.87,-452.6 1009.52,-452.6 1009.52,-452.6 1003.52,-452.6 997.52,-446.6 997.52,-440.6 997.52,-440.6 997.52,-401 997.52,-401 997.52,-395 1003.52,-389 1009.52,-389 1009.52,-389 1139.87,-389 1139.87,-389 1145.87,-389 1151.87,-395 1151.87,-401 1151.87,-401 1151.87,-440.6 1151.87,-440.6 1151.87,-446.6 1145.87,-452.6 1139.87,-452.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1008.32\" y=\"-429.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">columns_of_interest</text>\n",
       "<text text-anchor=\"start\" x=\"1066.45\" y=\"-401.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">list</text>\n",
       "</g>\n",
       "<!-- load_into_lancedb -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>load_into_lancedb</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1411.85,-456.57C1411.85,-460.96 1364.48,-464.52 1306.17,-464.52 1247.87,-464.52 1200.5,-460.96 1200.5,-456.57 1200.5,-456.57 1200.5,-385.02 1200.5,-385.02 1200.5,-380.64 1247.87,-377.07 1306.17,-377.07 1364.48,-377.07 1411.85,-380.64 1411.85,-385.02 1411.85,-385.02 1411.85,-456.57 1411.85,-456.57\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1411.85,-456.57C1411.85,-452.19 1364.48,-448.62 1306.17,-448.62 1247.87,-448.62 1200.5,-452.19 1200.5,-456.57\"/>\n",
       "<text text-anchor=\"start\" x=\"1246.92\" y=\"-429.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">load_into_lancedb</text>\n",
       "<text text-anchor=\"start\" x=\"1211.3\" y=\"-401.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">HuggingFaceDSLanceDBSaver</text>\n",
       "</g>\n",
       "<!-- columns_of_interest&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>columns_of_interest&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1152.11,-420.8C1163.89,-420.8 1176.28,-420.8 1188.69,-420.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1188.57,-424.3 1198.57,-420.8 1188.57,-417.3 1188.57,-424.3\"/>\n",
       "</g>\n",
       "<!-- final_dataset -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>final_dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1114.75,-370.6C1114.75,-370.6 1034.65,-370.6 1034.65,-370.6 1028.65,-370.6 1022.65,-364.6 1022.65,-358.6 1022.65,-358.6 1022.65,-319 1022.65,-319 1022.65,-313 1028.65,-307 1034.65,-307 1034.65,-307 1114.75,-307 1114.75,-307 1120.75,-307 1126.75,-313 1126.75,-319 1126.75,-319 1126.75,-358.6 1126.75,-358.6 1126.75,-364.6 1120.75,-370.6 1114.75,-370.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1033.45\" y=\"-347.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">final_dataset</text>\n",
       "<text text-anchor=\"start\" x=\"1051.45\" y=\"-319.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- final_dataset&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>final_dataset&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1127.2,-357.2C1145.82,-363.86 1167.68,-371.67 1189.58,-379.49\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1188.3,-382.75 1198.89,-382.82 1190.65,-376.16 1188.3,-382.75\"/>\n",
       "</g>\n",
       "<!-- lancedb_table -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>lancedb_table</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1118.5,-64.6C1118.5,-64.6 1030.9,-64.6 1030.9,-64.6 1024.9,-64.6 1018.9,-58.6 1018.9,-52.6 1018.9,-52.6 1018.9,-13 1018.9,-13 1018.9,-7 1024.9,-1 1030.9,-1 1030.9,-1 1118.5,-1 1118.5,-1 1124.5,-1 1130.5,-7 1130.5,-13 1130.5,-13 1130.5,-52.6 1130.5,-52.6 1130.5,-58.6 1124.5,-64.6 1118.5,-64.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1029.7\" y=\"-41.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lancedb_table</text>\n",
       "<text text-anchor=\"start\" x=\"1058.57\" y=\"-13.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Table</text>\n",
       "</g>\n",
       "<!-- lancedb_result -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>lancedb_result</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1352.6,-204.6C1352.6,-204.6 1259.75,-204.6 1259.75,-204.6 1253.75,-204.6 1247.75,-198.6 1247.75,-192.6 1247.75,-192.6 1247.75,-153 1247.75,-153 1247.75,-147 1253.75,-141 1259.75,-141 1259.75,-141 1352.6,-141 1352.6,-141 1358.6,-141 1364.6,-147 1364.6,-153 1364.6,-153 1364.6,-192.6 1364.6,-192.6 1364.6,-198.6 1358.6,-204.6 1352.6,-204.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1258.55\" y=\"-181.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">lancedb_result</text>\n",
       "<text text-anchor=\"start\" x=\"1295.67\" y=\"-153.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- lancedb_table&#45;&gt;lancedb_result -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>lancedb_table&#45;&gt;lancedb_result</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1130.98,-53.86C1144.53,-59.72 1158.79,-66.51 1171.5,-73.8 1201.73,-91.14 1233.16,-114.14 1257.95,-133.6\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1255.57,-136.18 1265.58,-139.65 1259.92,-130.69 1255.57,-136.18\"/>\n",
       "</g>\n",
       "<!-- named_entities -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>named_entities</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1123,-250.6C1123,-250.6 1026.4,-250.6 1026.4,-250.6 1020.4,-250.6 1014.4,-244.6 1014.4,-238.6 1014.4,-238.6 1014.4,-199 1014.4,-199 1014.4,-193 1020.4,-187 1026.4,-187 1026.4,-187 1123,-187 1123,-187 1129,-187 1135,-193 1135,-199 1135,-199 1135,-238.6 1135,-238.6 1135,-244.6 1129,-250.6 1123,-250.6\"/>\n",
       "<text text-anchor=\"start\" x=\"1025.2\" y=\"-227.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">named_entities</text>\n",
       "<text text-anchor=\"start\" x=\"1066.45\" y=\"-199.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">list</text>\n",
       "</g>\n",
       "<!-- named_entities&#45;&gt;lancedb_result -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>named_entities&#45;&gt;lancedb_result</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1135.39,-206.83C1166.28,-200.64 1204.14,-193.05 1236.26,-186.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1236.76,-190.08 1245.88,-184.69 1235.39,-183.22 1236.76,-190.08\"/>\n",
       "</g>\n",
       "<!-- tokenizer -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>tokenizer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M705.8,-232.6C705.8,-232.6 583.7,-232.6 583.7,-232.6 577.7,-232.6 571.7,-226.6 571.7,-220.6 571.7,-220.6 571.7,-181 571.7,-181 571.7,-175 577.7,-169 583.7,-169 583.7,-169 705.8,-169 705.8,-169 711.8,-169 717.8,-175 717.8,-181 717.8,-181 717.8,-220.6 717.8,-220.6 717.8,-226.6 711.8,-232.6 705.8,-232.6\"/>\n",
       "<text text-anchor=\"start\" x=\"615.12\" y=\"-209.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">tokenizer</text>\n",
       "<text text-anchor=\"start\" x=\"582.5\" y=\"-181.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedTokenizer</text>\n",
       "</g>\n",
       "<!-- ner_pipeline -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>ner_pipeline</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M890.65,-259.6C890.65,-259.6 813.55,-259.6 813.55,-259.6 807.55,-259.6 801.55,-253.6 801.55,-247.6 801.55,-247.6 801.55,-208 801.55,-208 801.55,-202 807.55,-196 813.55,-196 813.55,-196 890.65,-196 890.65,-196 896.65,-196 902.65,-202 902.65,-208 902.65,-208 902.65,-247.6 902.65,-247.6 902.65,-253.6 896.65,-259.6 890.65,-259.6\"/>\n",
       "<text text-anchor=\"start\" x=\"812.35\" y=\"-236.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">ner_pipeline</text>\n",
       "<text text-anchor=\"start\" x=\"828.1\" y=\"-208.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Pipeline</text>\n",
       "</g>\n",
       "<!-- tokenizer&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>tokenizer&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M718.15,-210.32C741.66,-213.41 767.5,-216.81 790.09,-219.78\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"789.37,-223.21 799.74,-221.05 790.28,-216.27 789.37,-223.21\"/>\n",
       "</g>\n",
       "<!-- model -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M695.67,-150.6C695.67,-150.6 593.83,-150.6 593.83,-150.6 587.83,-150.6 581.83,-144.6 581.83,-138.6 581.83,-138.6 581.83,-99 581.83,-99 581.83,-93 587.83,-87 593.83,-87 593.83,-87 695.67,-87 695.67,-87 701.67,-87 707.67,-93 707.67,-99 707.67,-99 707.67,-138.6 707.67,-138.6 707.67,-144.6 701.67,-150.6 695.67,-150.6\"/>\n",
       "<text text-anchor=\"start\" x=\"624.5\" y=\"-127.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model</text>\n",
       "<text text-anchor=\"start\" x=\"592.62\" y=\"-99.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">PreTrainedModel</text>\n",
       "</g>\n",
       "<!-- model&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>model&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M707.91,-148.2C714.27,-151.86 720.51,-155.75 726.3,-159.8 740.73,-169.89 740.54,-177.19 755.3,-186.8 766.21,-193.9 778.55,-200.25 790.65,-205.7\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"789.24,-208.9 799.8,-209.66 792.02,-202.48 789.24,-208.9\"/>\n",
       "</g>\n",
       "<!-- sampled_articles -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>sampled_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M906.02,-470.6C906.02,-470.6 798.17,-470.6 798.17,-470.6 792.17,-470.6 786.17,-464.6 786.17,-458.6 786.17,-458.6 786.17,-419 786.17,-419 786.17,-413 792.17,-407 798.17,-407 798.17,-407 906.02,-407 906.02,-407 912.02,-407 918.02,-413 918.02,-419 918.02,-419 918.02,-458.6 918.02,-458.6 918.02,-464.6 912.02,-470.6 906.02,-470.6\"/>\n",
       "<text text-anchor=\"start\" x=\"796.97\" y=\"-447.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">sampled_articles</text>\n",
       "<text text-anchor=\"start\" x=\"828.85\" y=\"-419.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- sampled_articles&#45;&gt;final_dataset -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>sampled_articles&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M918.31,-407.41C937.44,-398.35 958.45,-388.56 977.9,-379.8 988.89,-374.85 1000.64,-369.71 1012.02,-364.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1013.14,-368.13 1020.95,-360.96 1010.38,-361.69 1013.14,-368.13\"/>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>medium_articles.load_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M235.35,-563.6C235.35,-563.6 12,-563.6 12,-563.6 6,-563.6 0,-557.6 0,-551.6 0,-551.6 0,-512 0,-512 0,-506 6,-500 12,-500 12,-500 235.35,-500 235.35,-500 241.35,-500 247.35,-506 247.35,-512 247.35,-512 247.35,-551.6 247.35,-551.6 247.35,-557.6 241.35,-563.6 235.35,-563.6\"/>\n",
       "<text text-anchor=\"start\" x=\"10.8\" y=\"-540.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.load_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"106.8\" y=\"-512.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Tuple</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>medium_articles.select_data.dataset</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M522.2,-563.6C522.2,-563.6 288.35,-563.6 288.35,-563.6 282.35,-563.6 276.35,-557.6 276.35,-551.6 276.35,-551.6 276.35,-512 276.35,-512 276.35,-506 282.35,-500 288.35,-500 288.35,-500 522.2,-500 522.2,-500 528.2,-500 534.2,-506 534.2,-512 534.2,-512 534.2,-551.6 534.2,-551.6 534.2,-557.6 528.2,-563.6 522.2,-563.6\"/>\n",
       "<text text-anchor=\"start\" x=\"287.15\" y=\"-540.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles.select_data.dataset</text>\n",
       "<text text-anchor=\"start\" x=\"382.02\" y=\"-512.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>medium_articles.load_data.dataset&#45;&gt;medium_articles.select_data.dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M247.72,-531.8C253.29,-531.8 258.89,-531.8 264.49,-531.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"264.47,-535.3 274.47,-531.8 264.47,-528.3 264.47,-535.3\"/>\n",
       "</g>\n",
       "<!-- medium_articles -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>medium_articles</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M697.55,-563.6C697.55,-563.6 591.95,-563.6 591.95,-563.6 585.95,-563.6 579.95,-557.6 579.95,-551.6 579.95,-551.6 579.95,-512 579.95,-512 579.95,-506 585.95,-500 591.95,-500 591.95,-500 697.55,-500 697.55,-500 703.55,-500 709.55,-506 709.55,-512 709.55,-512 709.55,-551.6 709.55,-551.6 709.55,-557.6 703.55,-563.6 697.55,-563.6\"/>\n",
       "<text text-anchor=\"start\" x=\"590.75\" y=\"-540.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">medium_articles</text>\n",
       "<text text-anchor=\"start\" x=\"621.5\" y=\"-512.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Dataset</text>\n",
       "</g>\n",
       "<!-- medium_articles.select_data.dataset&#45;&gt;medium_articles -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>medium_articles.select_data.dataset&#45;&gt;medium_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M534.4,-531.8C546.02,-531.8 557.53,-531.8 568.46,-531.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"568.12,-535.3 578.12,-531.8 568.12,-528.3 568.12,-535.3\"/>\n",
       "</g>\n",
       "<!-- ner_pipeline&#45;&gt;final_dataset -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>ner_pipeline&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M902.88,-252.85C935.27,-269.14 977.73,-290.51 1012.22,-307.87\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1010.33,-310.83 1020.83,-312.2 1013.47,-304.58 1010.33,-310.83\"/>\n",
       "</g>\n",
       "<!-- ner_pipeline&#45;&gt;named_entities -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>ner_pipeline&#45;&gt;named_entities</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M902.88,-225.77C932.32,-224.57 970.1,-223.03 1002.63,-221.7\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1002.54,-225.21 1012.39,-221.3 1002.26,-218.21 1002.54,-225.21\"/>\n",
       "</g>\n",
       "<!-- retriever -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>retriever</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M916.52,-360.6C916.52,-360.6 787.67,-360.6 787.67,-360.6 781.67,-360.6 775.67,-354.6 775.67,-348.6 775.67,-348.6 775.67,-309 775.67,-309 775.67,-303 781.67,-297 787.67,-297 787.67,-297 916.52,-297 916.52,-297 922.52,-297 928.52,-303 928.52,-309 928.52,-309 928.52,-348.6 928.52,-348.6 928.52,-354.6 922.52,-360.6 916.52,-360.6\"/>\n",
       "<text text-anchor=\"start\" x=\"825.1\" y=\"-337.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">retriever</text>\n",
       "<text text-anchor=\"start\" x=\"786.47\" y=\"-309.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">SentenceTransformer</text>\n",
       "</g>\n",
       "<!-- retriever&#45;&gt;final_dataset -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>retriever&#45;&gt;final_dataset</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M929.02,-332.24C955.71,-333.45 985.39,-334.8 1010.89,-335.95\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1010.63,-339.44 1020.77,-336.4 1010.94,-332.45 1010.63,-339.44\"/>\n",
       "</g>\n",
       "<!-- retriever&#45;&gt;lancedb_result -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>retriever&#45;&gt;lancedb_result</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M928.85,-309.34C945.01,-305.37 961.99,-301.33 977.9,-297.8 1063.5,-278.78 1090.25,-292.76 1171.5,-259.8 1201.17,-247.76 1231.32,-228.85 1255.45,-211.72\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1257.45,-214.6 1263.5,-205.9 1253.35,-208.92 1257.45,-214.6\"/>\n",
       "</g>\n",
       "<!-- NER_model_id -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>NER_model_id</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M452.07,-191.6C452.07,-191.6 358.47,-191.6 358.47,-191.6 352.47,-191.6 346.47,-185.6 346.47,-179.6 346.47,-179.6 346.47,-140 346.47,-140 346.47,-134 352.47,-128 358.47,-128 358.47,-128 452.07,-128 452.07,-128 458.07,-128 464.07,-134 464.07,-140 464.07,-140 464.07,-179.6 464.07,-179.6 464.07,-185.6 458.07,-191.6 452.07,-191.6\"/>\n",
       "<text text-anchor=\"start\" x=\"357.27\" y=\"-168.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">NER_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"397.77\" y=\"-140.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;tokenizer -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;tokenizer</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-169.83C493.23,-174.81 528.63,-180.93 560.34,-186.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"559.44,-189.8 569.89,-188.05 560.63,-182.9 559.44,-189.8\"/>\n",
       "</g>\n",
       "<!-- NER_model_id&#45;&gt;model -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>NER_model_id&#45;&gt;model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M464.37,-149.77C496.32,-144.25 536.28,-137.35 570.37,-131.47\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"570.94,-134.92 580.2,-129.77 569.75,-128.03 570.94,-134.92\"/>\n",
       "</g>\n",
       "<!-- device -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>device</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M664.55,-314.6C664.55,-314.6 624.95,-314.6 624.95,-314.6 618.95,-314.6 612.95,-308.6 612.95,-302.6 612.95,-302.6 612.95,-263 612.95,-263 612.95,-257 618.95,-251 624.95,-251 624.95,-251 664.55,-251 664.55,-251 670.55,-251 676.55,-257 676.55,-263 676.55,-263 676.55,-302.6 676.55,-302.6 676.55,-308.6 670.55,-314.6 664.55,-314.6\"/>\n",
       "<text text-anchor=\"start\" x=\"623.75\" y=\"-291.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">device</text>\n",
       "<text text-anchor=\"start\" x=\"637.25\" y=\"-263.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- device&#45;&gt;ner_pipeline -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>device&#45;&gt;ner_pipeline</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.94,-274.45C706.96,-266.41 753.17,-254.03 790.45,-244.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"790.99,-247.52 799.74,-241.55 789.18,-240.76 790.99,-247.52\"/>\n",
       "</g>\n",
       "<!-- device&#45;&gt;retriever -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>device&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M676.94,-289.79C700.32,-295.02 733.53,-302.46 764.69,-309.44\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"763.6,-312.79 774.13,-311.56 765.13,-305.96 763.6,-312.79\"/>\n",
       "</g>\n",
       "<!-- medium_articles&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>medium_articles&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M709.88,-502.76C730.61,-493.37 753.8,-482.87 775.29,-473.13\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"776.64,-476.36 784.31,-469.05 773.76,-469.99 776.64,-476.36\"/>\n",
       "</g>\n",
       "<!-- _lancedb_table_inputs -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>_lancedb_table_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"948.9,-65.6 755.3,-65.6 755.3,0 948.9,0 948.9,-65.6\"/>\n",
       "<text text-anchor=\"start\" x=\"769.97\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">table_name</text>\n",
       "<text text-anchor=\"start\" x=\"883.1\" y=\"-37.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "<text text-anchor=\"start\" x=\"778.6\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">db_client</text>\n",
       "<text text-anchor=\"start\" x=\"847.1\" y=\"-16.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">DBConnection</text>\n",
       "</g>\n",
       "<!-- _lancedb_table_inputs&#45;&gt;lancedb_table -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>_lancedb_table_inputs&#45;&gt;lancedb_table</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M949.22,-32.8C968.74,-32.8 988.86,-32.8 1007.09,-32.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1006.96,-36.3 1016.96,-32.8 1006.96,-29.3 1006.96,-36.3\"/>\n",
       "</g>\n",
       "<!-- _named_entities_inputs -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>_named_entities_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"894.27,-178.1 809.92,-178.1 809.92,-133.5 894.27,-133.5 894.27,-178.1\"/>\n",
       "<text text-anchor=\"start\" x=\"824.72\" y=\"-150\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">query</text>\n",
       "<text text-anchor=\"start\" x=\"864.47\" y=\"-150\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _named_entities_inputs&#45;&gt;named_entities -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>_named_entities_inputs&#45;&gt;named_entities</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M894.59,-167.65C925.13,-176.37 967.39,-188.44 1003.19,-198.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1002.04,-201.97 1012.62,-201.36 1003.96,-195.24 1002.04,-201.97\"/>\n",
       "</g>\n",
       "<!-- _lancedb_result_inputs -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>_lancedb_result_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1127.5,-169.1 1021.9,-169.1 1021.9,-82.5 1127.5,-82.5 1127.5,-169.1\"/>\n",
       "<text text-anchor=\"start\" x=\"1036.7\" y=\"-141\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">prefilter</text>\n",
       "<text text-anchor=\"start\" x=\"1087.45\" y=\"-141\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n",
       "<text text-anchor=\"start\" x=\"1042.7\" y=\"-120\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">top_k</text>\n",
       "<text text-anchor=\"start\" x=\"1093.07\" y=\"-120\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"1042.32\" y=\"-99\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">query</text>\n",
       "<text text-anchor=\"start\" x=\"1092.7\" y=\"-99\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _lancedb_result_inputs&#45;&gt;lancedb_result -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>_lancedb_result_inputs&#45;&gt;lancedb_result</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1127.77,-136.47C1159.87,-143.04 1201.41,-151.55 1236.22,-158.68\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1235.28,-162.06 1245.78,-160.63 1236.68,-155.2 1235.28,-162.06\"/>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs -->\n",
       "<g id=\"node20\" class=\"node\">\n",
       "<title>_sampled_articles_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"719.05,-482.1 570.45,-482.1 570.45,-395.5 719.05,-395.5 719.05,-482.1\"/>\n",
       "<text text-anchor=\"start\" x=\"596.88\" y=\"-454\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sample_size</text>\n",
       "<text text-anchor=\"start\" x=\"690.12\" y=\"-454\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"584.88\" y=\"-433\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">max_text_length</text>\n",
       "<text text-anchor=\"start\" x=\"690.12\" y=\"-433\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"593.12\" y=\"-412\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">random_state</text>\n",
       "<text text-anchor=\"start\" x=\"690.12\" y=\"-412\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _sampled_articles_inputs&#45;&gt;sampled_articles -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>_sampled_articles_inputs&#45;&gt;sampled_articles</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M719.31,-438.8C737.23,-438.8 756.45,-438.8 774.52,-438.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"774.38,-442.3 784.38,-438.8 774.38,-435.3 774.38,-442.3\"/>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs -->\n",
       "<g id=\"node21\" class=\"node\">\n",
       "<title>_load_into_lancedb_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1171.5,-536.6 977.9,-536.6 977.9,-471 1171.5,-471 1171.5,-536.6\"/>\n",
       "<text text-anchor=\"start\" x=\"992.57\" y=\"-508.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">table_name</text>\n",
       "<text text-anchor=\"start\" x=\"1105.7\" y=\"-508.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "<text text-anchor=\"start\" x=\"1001.2\" y=\"-487.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">db_client</text>\n",
       "<text text-anchor=\"start\" x=\"1069.7\" y=\"-487.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">DBConnection</text>\n",
       "</g>\n",
       "<!-- _load_into_lancedb_inputs&#45;&gt;load_into_lancedb -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>_load_into_lancedb_inputs&#45;&gt;load_into_lancedb</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1167.71,-470.52C1174.92,-467.91 1182.25,-465.26 1189.58,-462.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1190.68,-465.93 1198.89,-459.24 1188.3,-459.35 1190.68,-465.93\"/>\n",
       "</g>\n",
       "<!-- _retriever_inputs -->\n",
       "<g id=\"node22\" class=\"node\">\n",
       "<title>_retriever_inputs</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"726.3,-377.1 563.2,-377.1 563.2,-332.5 726.3,-332.5 726.3,-377.1\"/>\n",
       "<text text-anchor=\"start\" x=\"578\" y=\"-349\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">retriever_model_id</text>\n",
       "<text text-anchor=\"start\" x=\"696.5\" y=\"-349\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">str</text>\n",
       "</g>\n",
       "<!-- _retriever_inputs&#45;&gt;retriever -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>_retriever_inputs&#45;&gt;retriever</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M726.61,-344.56C738.9,-343.01 751.63,-341.39 764.04,-339.82\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"764.35,-343.31 773.83,-338.58 763.47,-336.37 764.35,-343.31\"/>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node23\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"150.67,-618.1 96.67,-618.1 96.67,-581.5 150.67,-581.5 150.67,-618.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-594\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node24\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M146.1,-673.1C146.1,-673.1 101.25,-673.1 101.25,-673.1 95.25,-673.1 89.25,-667.1 89.25,-661.1 89.25,-661.1 89.25,-648.5 89.25,-648.5 89.25,-642.5 95.25,-636.5 101.25,-636.5 101.25,-636.5 146.1,-636.5 146.1,-636.5 152.1,-636.5 158.1,-642.5 158.1,-648.5 158.1,-648.5 158.1,-661.1 158.1,-661.1 158.1,-667.1 152.1,-673.1 146.1,-673.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-649\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- materializer -->\n",
       "<g id=\"node25\" class=\"node\">\n",
       "<title>materializer</title>\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M169.72,-728.34C169.72,-730.37 149.08,-732.01 123.67,-732.01 98.27,-732.01 77.62,-730.37 77.62,-728.34 77.62,-728.34 77.62,-695.26 77.62,-695.26 77.62,-693.23 98.27,-691.59 123.67,-691.59 149.08,-691.59 169.72,-693.23 169.72,-695.26 169.72,-695.26 169.72,-728.34 169.72,-728.34\"/>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M169.72,-728.34C169.72,-726.31 149.08,-724.66 123.67,-724.66 98.27,-724.66 77.62,-726.31 77.62,-728.34\"/>\n",
       "<text text-anchor=\"middle\" x=\"123.67\" y=\"-706\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">materializer</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x15a09eb30>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%incr_cell_to_module ner_module 5 --display \n",
    "\n",
    "import lancedb\n",
    "import numpy as np\n",
    "\n",
    "def named_entities(query: str, ner_pipeline: base.Pipeline) -> list[str]:\n",
    "    \"\"\"Named entities extracted from the query by the NER pipeline.\"\"\"\n",
    "    return _extract_named_entities_text([query], ner_pipeline)[0]\n",
    "\n",
    "def lancedb_table(db_client: lancedb.DBConnection, table_name: str = \"tw\") -> lancedb.table.Table:\n",
    "    \"\"\"The LanceDB table to query against.\"\"\"\n",
    "    tbl = db_client.open_table(table_name)\n",
    "    return tbl\n",
    "\n",
    "\n",
    "def lancedb_result(\n",
    "    query: str,\n",
    "    named_entities: list[str],\n",
    "    retriever: SentenceTransformer,\n",
    "    lancedb_table: lancedb.table.Table,\n",
    "    top_k: int = 10,\n",
    "    prefilter: bool = True,\n",
    ") -> dict:\n",
    "    \"\"\"Result of querying lancedb.\n",
    "\n",
    "    :param query: the query\n",
    "    :param named_entities: the named entities found in the query\n",
    "    :param retriever: the model to create the embedding from the query\n",
    "    :param lancedb_table: the lancedb table to query against\n",
    "    :param top_k: number of top results\n",
    "    :param prefilter: whether to apply the named-entity filter before the vector search (True) or only to the top-k vector search results (False)\n",
    "    :return: dictionary result\n",
    "    \"\"\"\n",
    "    # embed the query with the sentence-transformer retriever\n",
    "    query_vector = np.array(retriever.encode(query).tolist())\n",
    "\n",
    "    # query the lancedb table\n",
    "    query_builder = lancedb_table.search(query_vector, vector_column_name=\"vector\")\n",
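    "    if named_entities:\n",
    "        # Only filter when the NER model found entities in the query.\n",
    "        # Build the SQL array literal by hand, doubling single quotes (SQL escaping),\n",
    "        # so entities containing an apostrophe don't break the where clause.\n",
    "        entity_list = \", \".join(\"'\" + e.replace(\"'\", \"''\") + \"'\" for e in named_entities)\n",
    "        # keep rows whose named_entities column shares at least one entity with the query\n",
    "        where_clause = f\"array_length(array_intersect([{entity_list}], named_entities)) > 0\"\n",
    "        query_builder = query_builder.where(where_clause, prefilter=prefilter)\n",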
    "    result = (\n",
    "        query_builder.select([\"title\", \"url\", \"named_entities\"])  # what to return\n",
    "        .limit(top_k)\n",
    "        .to_list()\n",
    "    )\n",
    "    # the results could be re-ranked here before returning them\n",
    "    return {\"Query\": query, \"Query Entities\": named_entities, \"Result\": result}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "189c3647",
   "metadata": {},
   "source": [
    "# Execute some queries\n",
    "\n",
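    "We can now run a few queries against the articles stored in LanceDB. The first query contains no named entities, so only the vector search is applied; the second mentions \"Joe Biden\", so the results are also filtered to articles that list that entity."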
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "387d2fd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing node: NER_model_id.\n",
      "Finished debugging node: NER_model_id in 179μs. Status: Success.\n",
      "Executing node: model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stefankrawczyk/.pyenv/versions/3.10.4/envs/ner-example-py310/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: model in 584ms. Status: Success.\n",
      "Executing node: tokenizer.\n",
      "Finished debugging node: tokenizer in 152ms. Status: Success.\n",
      "Executing node: device.\n",
      "Finished debugging node: device in 42μs. Status: Success.\n",
      "Executing node: ner_pipeline.\n",
      "Finished debugging node: ner_pipeline in 2.02ms. Status: Success.\n",
      "Executing node: named_entities.\n",
      "Finished debugging node: named_entities in 52.3ms. Status: Success.\n",
      "Executing node: retriever.\n",
      "Finished debugging node: retriever in 1.75s. Status: Success.\n",
      "Executing node: lancedb_table.\n",
      "Finished debugging node: lancedb_table in 429μs. Status: Success.\n",
      "Executing node: lancedb_result.\n",
      "Finished debugging node: lancedb_result in 87.1ms. Status: Success.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'lancedb_result': {'Query': 'What is the future of autonomous vehicles?',\n",
       "  'Query Entities': [],\n",
       "  'Result': [{'title': 'Automated large scale data generation for autonomous vehicle',\n",
       "    'url': 'https://medium.com/mars-auto/automated-large-scale-data-generation-for-autonomous-vehicle-59de8b26357e',\n",
       "    'named_entities': ['Busan', 'Seoul'],\n",
       "    '_distance': 1.1834863424301147},\n",
       "   {'title': 'How to Scale Up Tech Solutions and Amplify Their Sustainability Impacts',\n",
       "    'url': 'https://medium.com/ksapa/how-to-scale-up-tech-solutions-and-amplify-their-sustainability-impacts-5124b192294d',\n",
       "    'named_entities': ['Augmented Reality',\n",
       "     'AI',\n",
       "     'Solutions',\n",
       "     'IoT',\n",
       "     'Machine Learning',\n",
       "     'Internet of Things',\n",
       "     'Virtual Reality',\n",
       "     'Global Goals',\n",
       "     'AR'],\n",
       "    '_distance': 1.3800410032272339},\n",
       "   {'title': 'Tech vs. Regulators: A Case In Point',\n",
       "    'url': 'https://medium.com/@nimishaagr/tech-vs-regulators-a-case-in-point-3959b8c81d27',\n",
       "    'named_entities': ['UK',\n",
       "     'New Delhi',\n",
       "     'Lianhao Qu',\n",
       "     'Chinese',\n",
       "     'Hong Kong',\n",
       "     'Unsplash',\n",
       "     'Comparitech',\n",
       "     'Christian Lange',\n",
       "     'In Point'],\n",
       "    '_distance': 1.5780136585235596},\n",
       "   {'title': 'Not the neighborhood he left: Biden’s international challenge',\n",
       "    'url': 'https://medium.com/@info-63603/not-the-neighborhood-he-left-bidens-international-challenge-d023f7ed26d0',\n",
       "    'named_entities': ['United States',\n",
       "     'Joe Biden',\n",
       "     'Russia',\n",
       "     'American',\n",
       "     'Trump',\n",
       "     'Biden',\n",
       "     'China'],\n",
       "    '_distance': 1.5939494371414185},\n",
       "   {'title': '⚡ Mega US Solar project announced, UK readies for 12GW Renewables Auction, Carbon Pricing as a Motivator, and Blockchain — Climate friend or foe?',\n",
       "    'url': 'https://medium.com/the-carbon-cut/mega-us-solar-project-announced-uk-readies-for-12gw-renewables-auction-carbon-pricing-as-a-24f026c13a1e',\n",
       "    'named_entities': ['UK',\n",
       "     'Mega US Solar',\n",
       "     'America',\n",
       "     'Carbon Pricing',\n",
       "     'Carbon',\n",
       "     'Blockchain',\n",
       "     'U. S',\n",
       "     'Jakub Rzeplinski',\n",
       "     'Goldman Sachs',\n",
       "     'Cut',\n",
       "     'Energy Ministry',\n",
       "     'Power',\n",
       "     'Energy Industry',\n",
       "     'European Environment Agency',\n",
       "     'American Wind Energy Association',\n",
       "     'Contracts for Difference Round Four',\n",
       "     'Auction',\n",
       "     'Solar Project'],\n",
       "    '_distance': 1.6431432962417603},\n",
       "   {'title': 'Article Comment: We fight networks by realizing we are networks.',\n",
       "    'url': 'https://medium.com/greyswandigital/michael-in-times-of-crisis-when-verification-of-information-sources-may-be-incomplete-we-also-f8788439e9c4',\n",
       "    'named_entities': ['Michael'],\n",
       "    '_distance': 1.6467256546020508},\n",
       "   {'title': 'The EU-Asia Connectivity Strategy',\n",
       "    'url': 'https://medium.com/freeman-spogli-institute-for-international-studies/the-eu-asia-connectivity-strategy-8ce605a5d8a4',\n",
       "    'named_entities': ['East Asia',\n",
       "     'EU',\n",
       "     'Stanford University',\n",
       "     'Europe',\n",
       "     'Central',\n",
       "     'BRI',\n",
       "     'Justin Tomczyk',\n",
       "     'Economic Policy Research Center',\n",
       "     'East European',\n",
       "     'Belt',\n",
       "     'and Road',\n",
       "     'Southeast',\n",
       "     'Eurasian Studies',\n",
       "     'European Commission',\n",
       "     'Initiative',\n",
       "     'FSI Global',\n",
       "     'Asia Connectivity Strategy',\n",
       "     'Russian'],\n",
       "    '_distance': 1.6500868797302246},\n",
       "   {'title': 'Police Brutality During COVID19 Continues Unabated',\n",
       "    'url': 'https://extremearturo.medium.com/police-brutality-during-covid19-continues-unabated-10d427c4225c',\n",
       "    'named_entities': ['America',\n",
       "     'United Nations',\n",
       "     'Asia',\n",
       "     'United States',\n",
       "     'Americans',\n",
       "     'Latin America',\n",
       "     'COVID19',\n",
       "     'UN',\n",
       "     'Fibonacci Blue',\n",
       "     'Africa',\n",
       "     'Creative Commons'],\n",
       "    '_distance': 1.6865490674972534},\n",
       "   {'title': 'Poeple’s of Israel are protesting against Netanyahu',\n",
       "    'url': 'https://medium.com/@bazranorotlo/poeples-of-israel-are-protesting-against-netanyahu-cc31b4a04a8f',\n",
       "    'named_entities': ['Benny Gantz',\n",
       "     'Blue and White',\n",
       "     'Poeple ’ s',\n",
       "     'Knesset',\n",
       "     'COVID - 19',\n",
       "     'Saudi Arabia',\n",
       "     'United Arab Emirates',\n",
       "     'Gantz',\n",
       "     'Bahrain',\n",
       "     'Likud',\n",
       "     'Israel',\n",
       "     'Netanyahu'],\n",
       "    '_distance': 1.6886379718780518},\n",
       "   {'title': 'Weekly Digest: New Livestream, #10YearsChallenge and an Ultimate Guide to our Customer Support',\n",
       "    'url': 'https://medium.com/crypterium/weeklydigest-livestream-10yearchallenge-support-2733b5f9f944',\n",
       "    'named_entities': ['Lisbon',\n",
       "     'Rafael Carrascosa',\n",
       "     'Facebook',\n",
       "     'Moscow',\n",
       "     'Rafael',\n",
       "     'Crypterium',\n",
       "     'Hong Kong',\n",
       "     'Youtube',\n",
       "     'Global Partnerships',\n",
       "     'London',\n",
       "     'GMT'],\n",
       "    '_distance': 1.6984761953353882}]}}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dr_query = (\n",
    "    driver.Builder()\n",
    "    .with_config({})\n",
    "    .with_modules(ner_module)\n",
    "    .with_adapters(lifecycle.PrintLn())\n",
    "    .build()\n",
    ")\n",
    "dr_query.execute(\n",
    "    [\"lancedb_result\"],\n",
    "    inputs={\n",
    "        \"table_name\": table_name,\n",
    "        \"query\": \"What is the future of autonomous vehicles?\",\n",
    "        \"db_client\": db_client\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "af54bccd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing node: NER_model_id.\n",
      "Finished debugging node: NER_model_id in 309μs. Status: Success.\n",
      "Executing node: model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: model in 405ms. Status: Success.\n",
      "Executing node: tokenizer.\n",
      "Finished debugging node: tokenizer in 133ms. Status: Success.\n",
      "Executing node: device.\n",
      "Finished debugging node: device in 39.3μs. Status: Success.\n",
      "Executing node: ner_pipeline.\n",
      "Finished debugging node: ner_pipeline in 1.9ms. Status: Success.\n",
      "Executing node: named_entities.\n",
      "Finished debugging node: named_entities in 26.4ms. Status: Success.\n",
      "Executing node: retriever.\n",
      "Finished debugging node: retriever in 1.28s. Status: Success.\n",
      "Executing node: lancedb_table.\n",
      "Finished debugging node: lancedb_table in 371μs. Status: Success.\n",
      "Executing node: lancedb_result.\n",
      "Finished debugging node: lancedb_result in 64.9ms. Status: Success.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'lancedb_result': {'Query': 'Who is Joe Biden?',\n",
       "  'Query Entities': ['Joe Biden'],\n",
       "  'Result': [{'title': 'Not the neighborhood he left: Biden’s international challenge',\n",
       "    'url': 'https://medium.com/@info-63603/not-the-neighborhood-he-left-bidens-international-challenge-d023f7ed26d0',\n",
       "    'named_entities': ['United States',\n",
       "     'Joe Biden',\n",
       "     'Russia',\n",
       "     'American',\n",
       "     'Trump',\n",
       "     'Biden',\n",
       "     'China'],\n",
       "    '_distance': 0.9555794596672058}]}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dr_query.execute(\n",
    "    [\"lancedb_result\"], \n",
    "    inputs={\n",
    "        \"table_name\": table_name, \n",
    "        \"query\": \"Who is Joe Biden?\",\n",
    "        \"db_client\": db_client\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e8fe0207",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing node: NER_model_id.\n",
      "Finished debugging node: NER_model_id in 99.9μs. Status: Success.\n",
      "Executing node: model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished debugging node: model in 273ms. Status: Success.\n",
      "Executing node: tokenizer.\n",
      "Finished debugging node: tokenizer in 130ms. Status: Success.\n",
      "Executing node: device.\n",
      "Finished debugging node: device in 25.3μs. Status: Success.\n",
      "Executing node: ner_pipeline.\n",
      "Finished debugging node: ner_pipeline in 1.78ms. Status: Success.\n",
      "Executing node: named_entities.\n",
      "Finished debugging node: named_entities in 27.5ms. Status: Success.\n",
      "Executing node: retriever.\n",
      "Finished debugging node: retriever in 1.15s. Status: Success.\n",
      "Executing node: lancedb_table.\n",
      "Finished debugging node: lancedb_table in 149μs. Status: Success.\n",
      "Executing node: lancedb_result.\n",
      "Finished debugging node: lancedb_result in 27.1ms. Status: Success.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'lancedb_result': {'Query': 'How Data is changing the world?',\n",
       "  'Query Entities': ['Data'],\n",
       "  'Result': []}}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dr_query.execute(\n",
    "    [\"lancedb_result\"], \n",
    "    inputs={\n",
    "        \"table_name\": table_name, \n",
    "        \"query\": \"How Data is changing the world?\",\n",
    "        \"db_client\": db_client\n",
    "    }\n",
    ")"
   ]
  },
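  {
   "cell_type": "markdown",
   "id": "5a7d3e10",
   "metadata": {},
   "source": [
    "The last query returned an empty `Result` list. The only entity extracted from the query was `Data`, and since the query entities are used to filter the vector search results, articles whose `named_entities` don't match are dropped. Presumably none of the stored articles had `Data` tagged as a named entity, so every hit was filtered out.\n",
    "\n",
    "One way to handle this is to post-filter and fall back to the unfiltered nearest neighbors when the entity filter removes every hit. The cell below is a minimal sketch in plain Python, assuming you already have the unfiltered vector search hits as dicts shaped like the `Result` entries printed above. `filter_by_entities` and its `fallback_to_unfiltered` flag are hypothetical, not part of `ner_module` or LanceDB, and the case-insensitive matching is a design choice you may want to revisit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7d3e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "def filter_by_entities(\n",
    "    hits: list[dict[str, Any]],\n",
    "    query_entities: list[str],\n",
    "    fallback_to_unfiltered: bool = True,\n",
    ") -> list[dict[str, Any]]:\n",
    "    # Hypothetical helper, not part of ner_module or LanceDB.\n",
    "    # Keeps hits whose 'named_entities' overlap the query entities (case-insensitive).\n",
    "    if not query_entities:\n",
    "        # No entities in the query: return the plain vector search hits.\n",
    "        return hits\n",
    "    wanted = {entity.lower() for entity in query_entities}\n",
    "    kept = [\n",
    "        hit\n",
    "        for hit in hits\n",
    "        if wanted & {entity.lower() for entity in hit.get('named_entities', [])}\n",
    "    ]\n",
    "    if not kept and fallback_to_unfiltered:\n",
    "        # The filter removed every hit: fall back to the unfiltered nearest neighbors.\n",
    "        return hits\n",
    "    return kept\n",
    "\n",
    "\n",
    "# A hit copied from the 'Who is Joe Biden?' output above.\n",
    "joe_biden_hit = {\n",
    "    'title': 'Not the neighborhood he left: Biden’s international challenge',\n",
    "    'url': 'https://medium.com/@info-63603/not-the-neighborhood-he-left-bidens-international-challenge-d023f7ed26d0',\n",
    "    'named_entities': ['United States', 'Joe Biden', 'Russia', 'American', 'Trump', 'Biden', 'China'],\n",
    "    '_distance': 0.9555794596672058,\n",
    "}\n",
    "\n",
    "print(len(filter_by_entities([joe_biden_hit], ['Joe Biden'])))  # 1: entities overlap\n",
    "print(len(filter_by_entities([joe_biden_hit], ['Data'])))  # 1: no overlap, falls back\n",
    "print(len(filter_by_entities([joe_biden_hit], ['Data'], fallback_to_unfiltered=False)))  # 0"
   ]
  },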
  {
   "cell_type": "markdown",
   "id": "099a0381",
   "metadata": {},
   "source": [
    "# Connect with tracking & telemetry\n",
    "To gain more visibility into execution, we can connect to the [Hamilton UI](https://blog.dagworks.io/p/hamilton-ui-streamlining-metadata?r=2cg5z1&utm_campaign=post&utm_medium=web). The following code assumes you already have the Hamilton UI running locally; update `project_id` and `username` to match your own project. If you're looking for the SaaS version, sign up for the free tier at [DAGWorks Inc](https://www.dagworks.io/hamliton)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c21dd0d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Capturing execution run. Results can be found at http://localhost:8242/dashboard/project/41/runs/51\n",
      "\n",
      "/Users/stefankrawczyk/.pyenv/versions/3.10.4/envs/ner-example-py310/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n",
      "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "\n",
      "Captured execution run. Results can be found at http://localhost:8242/dashboard/project/41/runs/51\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'lancedb_result': {'Query': 'Who is Joe Biden?',\n",
       "  'Query Entities': ['Joe Biden'],\n",
       "  'Result': [{'title': 'Not the neighborhood he left: Biden’s international challenge',\n",
       "    'url': 'https://medium.com/@info-63603/not-the-neighborhood-he-left-bidens-international-challenge-d023f7ed26d0',\n",
       "    'named_entities': ['United States',\n",
       "     'Joe Biden',\n",
       "     'Russia',\n",
       "     'American',\n",
       "     'Trump',\n",
       "     'Biden',\n",
       "     'China'],\n",
       "    '_distance': 0.9555794596672058}]}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from hamilton_sdk import adapters\n",
    "from hamilton import driver\n",
    "import uuid\n",
    "\n",
    "import lancedb\n",
    "table_name = \"medium_docs\"\n",
    "db_client = lancedb.connect(\"./.lancedb\")\n",
    "RUN_ID = str(uuid.uuid4())\n",
    "\n",
    "tracker = adapters.HamiltonTracker(\n",
    "    project_id=41,                  # <--- modify this \n",
    "    username=\"elijah@dagworks.io\",  # <--- modify this \n",
    "    dag_name=\"ner-lancedb-pipeline\",\n",
    "    tags={\"context\": \"querying\",\n",
    "          \"team\": \"MY_TEAM\",\n",
    "          \"run_id\":  RUN_ID,\n",
    "          \"version\": \"1\"},\n",
    ")\n",
    "dr_query = (\n",
    "    driver.Builder()\n",
    "    .with_config({})\n",
    "    .with_modules(ner_module)\n",
    "    .with_adapters(tracker)\n",
    "    .build()\n",
    ")\n",
    "dr_query.execute(\n",
    "    [\"lancedb_result\"], \n",
    "    inputs={\n",
    "        \"table_name\": table_name, \n",
    "        \"query\": \"Who is Joe Biden?\",\n",
    "        \"db_client\": db_client\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9d7cd2f",
   "metadata": {},
   "source": [
    "# Summary\n",
    "In this notebook we:\n",
    "\n",
    "1. incrementally built a pipeline to process Medium articles\n",
    "2. extracted named entities from the articles\n",
    "3. created vector embeddings from the text\n",
    "4. pushed all the data into LanceDB to query against\n",
    "\n",
    "# Next steps to combine with RAG\n",
    "We now have a database of Medium articles that we can query via cosine similarity, along with\n",
    "extra metadata (in this case, the named entities extracted from the text) that helps us filter results.\n",
    "\n",
    "With this general blueprint, you can experiment with what context to retrieve for a given user query,\n",
    "and then use that context to populate a prompt to send to an LLM.\n",
    "\n",
    "For example, we could take the returned URLs and load the documents from them, or\n",
    "adjust what is stored in LanceDB so that it returns the text directly. If you'd\n",
    "like to build a conversational agent, we refer you to Hamilton's sister framework\n",
    "[Burr](https://github.com/apache/burr), which can help you build, curate,\n",
    "and debug your application.\n",
    "\n",
    "\n",
    "# Extensions\n",
    "There are many ways to extend this pipeline. Here are a few ideas:\n",
    "\n",
    "1. Use a different NER model.\n",
    "2. Use a different embedding model.\n",
    "3. Use a different database.\n",
    "4. Use additional metadata, e.g. ACLs if applicable, to filter the results.\n",
    "5. Use query expansion to improve the results, e.g. by expanding the entities extracted from the query.\n",
    "6. Use a re-ranking algorithm to rank the results.\n",
    "7. Work on document chunking to optimize for your particular RAG use case."
   ]
  },
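  {
   "cell_type": "markdown",
   "id": "9c4b2f20",
   "metadata": {},
   "source": [
    "As a starting point for combining this with RAG, here is a minimal sketch that turns the `lancedb_result` output into a prompt for an LLM. `build_rag_prompt` and the prompt wording are hypothetical illustrations, not part of Hamilton, LanceDB, or `ner_module`. Because this pipeline only stores titles, URLs, and named entities, the sketch uses just those fields; to pass the article text itself to the LLM, you would need to load the documents from their URLs or store the text in LanceDB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c4b2f21",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "def build_rag_prompt(lancedb_result: dict[str, Any], max_docs: int = 3) -> str:\n",
    "    # Hypothetical helper: formats retrieved articles as context for an LLM prompt.\n",
    "    # A lower '_distance' means a closer match, so keep the max_docs closest hits.\n",
    "    hits = sorted(lancedb_result['Result'], key=lambda hit: hit['_distance'])[:max_docs]\n",
    "    context_lines = []\n",
    "    for hit in hits:\n",
    "        title = hit['title']\n",
    "        url = hit['url']\n",
    "        entities = ', '.join(hit['named_entities'])\n",
    "        context_lines.append(f'- {title} ({url}); entities: {entities}')\n",
    "    # An empty 'Result' list (as for the 'Data' query) still yields a valid prompt.\n",
    "    context = '\\n'.join(context_lines) if context_lines else '- (no matching articles)'\n",
    "    query = lancedb_result['Query']\n",
    "    return (\n",
    "        'Answer the question using the articles below as context.\\n'\n",
    "        f'Articles:\\n{context}\\n\\n'\n",
    "        f'Question: {query}\\n'\n",
    "    )\n",
    "\n",
    "\n",
    "# The 'Who is Joe Biden?' result printed above. With the driver defined earlier you could\n",
    "# instead use: dr_query.execute(['lancedb_result'], inputs=...)['lancedb_result']\n",
    "example_result = {\n",
    "    'Query': 'Who is Joe Biden?',\n",
    "    'Query Entities': ['Joe Biden'],\n",
    "    'Result': [\n",
    "        {\n",
    "            'title': 'Not the neighborhood he left: Biden’s international challenge',\n",
    "            'url': 'https://medium.com/@info-63603/not-the-neighborhood-he-left-bidens-international-challenge-d023f7ed26d0',\n",
    "            'named_entities': ['United States', 'Joe Biden', 'Russia', 'American', 'Trump', 'Biden', 'China'],\n",
    "            '_distance': 0.9555794596672058,\n",
    "        }\n",
    "    ],\n",
    "}\n",
    "print(build_rag_prompt(example_result))"
   ]
  },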
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b064c176",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
